diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2d86b869b..777ea7dbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,8 +53,19 @@ repos: args: - --config-file=pyproject.toml additional_dependencies: - - pydantic~=2.0 + - "aio-eapi==0.3.0" + - "click==8.1.3" + - "click-help-colors==0.9.1" + - "cvprac>=1.2.0" + - "netaddr>=0.8.0" + - "pydantic~=2.0" + - "PyYAML>=6.0" + - "requests" + - "rich>=12.5.1" + - "scp" + - "asyncssh==2.13.1" + - "Jinja2>=3.1.2" - types-PyYAML - types-paramiko - types-requests - files: ^(anta|scripts|tests)/ + files: ^(anta|tests)/ diff --git a/anta/cli/exec/utils.py b/anta/cli/exec/utils.py index 6d0afe29b..edf352459 100644 --- a/anta/cli/exec/utils.py +++ b/anta/cli/exec/utils.py @@ -66,10 +66,10 @@ async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "tex logger.error(f"Could not collect commands on device {dev.name}: {exc_to_str(c.failed)}") return if c.ofmt == "json": - outfile = outdir / (command + ".json") + outfile = outdir / f"{command}.json" content = json.dumps(c.json_output, indent=2) elif c.ofmt == "text": - outfile = outdir / (command + ".log") + outfile = outdir / f"{command}.log" content = c.text_output with outfile.open(mode="w", encoding="UTF-8") as f: f.write(content) @@ -88,7 +88,7 @@ async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "tex if __DEBUG__: logger.exception(message, exc_info=r) else: - logger.error(message + f": {exc_to_str(r)}") + logger.error(f"{message}: {exc_to_str(r)}") async def collect_scheduled_show_tech(inv: AntaInventory, root_dir: Path, configure: bool, tags: Optional[List[str]] = None, latest: Optional[int] = None) -> None: @@ -150,7 +150,7 @@ async def collect(device: AntaDevice) -> None: if __DEBUG__: logger.exception(message) else: - logger.error(message + f": {exc_to_str(e)}") + logger.error(f"{message}: {exc_to_str(e)}") logger.info("Connecting to devices...") await inv.connect_inventory() diff --git a/anta/cli/nrfu/commands.py b/anta/cli/nrfu/commands.py index 289765961..d53be8f17 100644 --- a/anta/cli/nrfu/commands.py +++ b/anta/cli/nrfu/commands.py @@ -13,10 +13,11 @@ import click from anta.cli.utils import parse_tags, return_code +from anta.models import AntaTest from anta.result_manager import ResultManager from anta.runner import main -from .utils import print_jinja, print_json, print_settings, print_table, print_text +from .utils import anta_progress_bar, print_jinja, print_json, print_settings, print_table, print_text logger = logging.getLogger(__name__) @@ -30,7 +31,8 @@ def table(ctx: click.Context, tags: Optional[List[str]], device: Optional[str], """ANTA command to check network states with table result""" print_settings(ctx) results = ResultManager() - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) + with anta_progress_bar() as AntaTest.progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) print_table(results=results, device=device, test=test) # TODO make a util method to avoid repeating the same three line @@ -54,7 +56,8 @@ def json(ctx: click.Context, tags: Optional[List[str]], output: Optional[pathlib """ANTA command to check network state with JSON result""" print_settings(ctx) results = ResultManager() - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) + with anta_progress_bar() as AntaTest.progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) print_json(results=results, output=output) ignore_status = ctx.obj["ignore_status"] @@ -71,7 +74,8 @@ def text(ctx: click.Context, tags: Optional[List[str]], search: Optional[str], s """ANTA command to check network states with text result""" print_settings(ctx) results = ResultManager() - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) + with anta_progress_bar() as AntaTest.progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) print_text(results=results, search=search, skip_error=skip_error) ignore_status = ctx.obj["ignore_status"] @@ -102,7 +106,8 @@ def tpl_report(ctx: click.Context, tags: Optional[List[str]], template: pathlib. """ANTA command to check network state with templated report""" print_settings(ctx, template, output) results = ResultManager() - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) + with anta_progress_bar() as AntaTest.progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) print_jinja(results=results, template=template, output=output) ignore_status = ctx.obj["ignore_status"] diff --git a/anta/cli/nrfu/utils.py b/anta/cli/nrfu/utils.py index de821e19e..9143ca4af 100644 --- a/anta/cli/nrfu/utils.py +++ b/anta/cli/nrfu/utils.py @@ -15,6 +15,7 @@ import rich from rich.panel import Panel from rich.pretty import pprint +from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from anta.cli.console import console from anta.reporter import ReportJinja, ReportTable @@ -80,3 +81,44 @@ def print_jinja(results: ResultManager, template: pathlib.Path, output: Optional if output is not None: with open(output, "w", encoding="utf-8") as file: file.write(report) + + +# Adding our own ANTA spinner - overriding rich SPINNERS for our own +# so ignore warning for redefinition +rich.spinner.SPINNERS = { # type: ignore[attr-defined] # noqa: F811 + "anta": { + "interval": 150, + "frames": [ + "( 🐜)", + "( 🐜 )", + "( 🐜 )", + "( 🐜 )", + "( 🐜 )", + "(🐜 )", + "(🐌 )", + "( 🐌 )", + "( 🐌 )", + "( 🐌 )", + "( 🐌 )", + "( 🐌)", + ], + } +} + + +def anta_progress_bar() -> Progress: + """ + Return a customized Progress for progress bar + """ + return Progress( + SpinnerColumn("anta"), + TextColumn("•"), + TextColumn("{task.description}[progress.percentage]{task.percentage:>3.0f}%"), + BarColumn(bar_width=None), + MofNCompleteColumn(), + TextColumn("•"), + TimeElapsedColumn(), + TextColumn("•"), + TimeRemainingColumn(), + expand=True, + ) diff --git a/anta/decorators.py b/anta/decorators.py index da0a144a5..a22313f33 100644 --- a/anta/decorators.py +++ b/anta/decorators.py @@ -4,7 +4,7 @@ from functools import wraps from typing import Any, Callable, Dict, List, TypeVar, cast -from anta.models import AntaCommand +from anta.models import AntaCommand, AntaTest from anta.result_manager.models import TestResult from anta.tools.misc import exc_to_str @@ -35,10 +35,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: anta_test = args[0] if anta_test.result.result != "unset": + AntaTest.update_progress() return anta_test.result if anta_test.device.hw_model in platforms: anta_test.result.is_skipped(f"{anta_test.__class__.__name__} test is not supported on {anta_test.device.hw_model}.") + AntaTest.update_progress() return anta_test.result return await function(*args, **kwargs) @@ -71,6 +73,7 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: anta_test = args[0] if anta_test.result.result != "unset": + AntaTest.update_progress() return anta_test.result if family == "ipv4": @@ -92,10 +95,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: return anta_test.result if "vrfs" not in command.json_output: anta_test.result.is_skipped(f"no BGP configuration for {family} on this device") + AntaTest.update_progress() return anta_test.result if len(bgp_vrfs := command.json_output["vrfs"]) == 0 or len(bgp_vrfs["default"]["peers"]) == 0: # No VRF anta_test.result.is_skipped(f"no {family} peer on this device") + AntaTest.update_progress() return anta_test.result return await function(*args, **kwargs) diff --git a/anta/inventory/__init__.py b/anta/inventory/__init__.py index dd414dcdb..617651bb6 100644 --- a/anta/inventory/__init__.py +++ b/anta/inventory/__init__.py @@ -170,4 +170,4 @@ async def connect_inventory(self) -> None: if __DEBUG__: logger.exception(message, exc_info=r) else: - logger.error(message + f": {exc_to_str(r)}") + logger.error(f"{message}: {exc_to_str(r)}") diff --git a/anta/models.py b/anta/models.py index e57821c1a..7e5b22f0e 100644 --- a/anta/models.py +++ b/anta/models.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Coroutine, Dict, List, Literal, Optional, TypeVar, Union from pydantic import BaseModel, ConfigDict, conint +from rich.progress import Progress, TaskID from anta import __DEBUG__ from anta.result_manager.models import TestResult @@ -151,6 +152,8 @@ class AntaTest(ABC): commands: ClassVar[list[AntaCommand]] # TODO - today we support only one template per Test template: ClassVar[AntaTemplate] + progress: Optional[Progress] = None + nrfu_task: Optional[TaskID] = None # Optional class attributes test_filters: ClassVar[list[AntaTestFilter]] @@ -164,6 +167,8 @@ def __init__( labels: list[str] | None = None, ): """Class constructor""" + # Accept 6 input arguments + # pylint: disable=R0913 self.logger: logging.Logger = logging.getLogger(f"{self.__module__}.{self.__class__.__name__}") self.device: AntaDevice = device self.result: TestResult = TestResult(name=device.name, test=self.name, test_category=self.categories, test_description=self.description) @@ -204,11 +209,7 @@ def all_data_collected(self) -> bool: def get_failed_commands(self) -> List[AntaCommand]: """returns a list of all the commands that have a populated failed field""" - errors = [] - for command in self.instance_commands: - if command.failed is not None: - errors.append(command) - return errors + return [command for command in self.instance_commands if command.failed is not None] def __init_subclass__(cls) -> None: """ @@ -233,7 +234,7 @@ async def collect(self) -> None: if __DEBUG__: self.logger.exception(message) else: - self.logger.error(message + f": {exc_to_str(e)}") + self.logger.error(f"{message}: {exc_to_str(e)}") self.result.is_error(exc_to_str(e)) @staticmethod @@ -286,12 +287,22 @@ async def wrapper( if __DEBUG__: self.logger.exception(message) else: - self.logger.error(message + f": {exc_to_str(e)}") + self.logger.error(f"{message}: {exc_to_str(e)}") self.result.is_error(exc_to_str(e)) + + AntaTest.update_progress() return self.result return wrapper + @classmethod + def update_progress(cls) -> None: + """ + Update progress bar for all AntaTest objects if it exists + """ + if cls.progress and (cls.nrfu_task is not None): + cls.progress.update(cls.nrfu_task, advance=1) + @abstractmethod def test(self) -> Coroutine[Any, Any, TestResult]: """ diff --git a/anta/runner.py b/anta/runner.py index 488fb201c..18be2d5d2 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -9,6 +9,7 @@ from anta import __DEBUG__ from anta.inventory import AntaInventory +from anta.models import AntaTest from anta.result_manager import ResultManager from anta.result_manager.models import TestResult from anta.tools.misc import exc_to_str @@ -45,6 +46,8 @@ async def main( Returns: any: List of results. """ + # Accept 6 arguments here + # pylint: disable=R0913 await inventory.connect_inventory() @@ -64,7 +67,10 @@ async def main( if __DEBUG__: logger.exception(message) else: - logger.error(message + f": {exc_to_str(e)}") + logger.error(f"{message}: {exc_to_str(e)}") + + if AntaTest.progress is not None: + AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coros)) logger.info("Running ANTA tests...") res = await asyncio.gather(*coros, return_exceptions=True) @@ -74,6 +80,6 @@ async def main( if __DEBUG__: logger.exception(message, exc_info=r) else: - logger.error(message + f": {exc_to_str(r)}") + logger.error(f"{message}: {exc_to_str(r)}") res.remove(r) manager.add_test_results(res) diff --git a/pylintrc b/pylintrc index ddbe1b88c..a02d42c67 100644 --- a/pylintrc +++ b/pylintrc @@ -2,7 +2,6 @@ disable= invalid-name, logging-fstring-interpolation, - logging-not-lazy, fixme [BASIC]