From d53e86aa592ce3a743e0baea8655dd771d0affc8 Mon Sep 17 00:00:00 2001 From: gmuloc Date: Wed, 5 Jul 2023 14:05:45 +0200 Subject: [PATCH 1/8] Feat(anta.cli): Add Progress Bar for NRFU --- .pre-commit-config.yaml | 15 +++++++++++-- anta/cli/nrfu/commands.py | 14 +++++++----- anta/cli/nrfu/utils.py | 46 +++++++++++++++++++++++++++++++++++++++ anta/models.py | 22 ++++++++++++------- anta/runner.py | 14 +++++++++--- pylintrc | 1 - 6 files changed, 93 insertions(+), 19 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2d86b869b..777ea7dbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,8 +53,19 @@ repos: args: - --config-file=pyproject.toml additional_dependencies: - - pydantic~=2.0 + - "aio-eapi==0.3.0" + - "click==8.1.3" + - "click-help-colors==0.9.1" + - "cvprac>=1.2.0" + - "netaddr>=0.8.0" + - "pydantic~=2.0" + - "PyYAML>=6.0" + - "requests" + - "rich>=12.5.1" + - "scp" + - "asyncssh==2.13.1" + - "Jinja2>=3.1.2" - types-PyYAML - types-paramiko - types-requests - files: ^(anta|scripts|tests)/ + files: ^(anta|tests)/ diff --git a/anta/cli/nrfu/commands.py b/anta/cli/nrfu/commands.py index 289765961..00aeb4b29 100644 --- a/anta/cli/nrfu/commands.py +++ b/anta/cli/nrfu/commands.py @@ -16,7 +16,7 @@ from anta.result_manager import ResultManager from anta.runner import main -from .utils import print_jinja, print_json, print_settings, print_table, print_text +from .utils import anta_progress_bar, print_jinja, print_json, print_settings, print_table, print_text logger = logging.getLogger(__name__) @@ -30,7 +30,8 @@ def table(ctx: click.Context, tags: Optional[List[str]], device: Optional[str], """ANTA command to check network states with table result""" print_settings(ctx) results = ResultManager() - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) + with anta_progress_bar() as progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress)) print_table(results=results, device=device, test=test) # TODO make a util method to avoid repeating the same three line @@ -54,7 +55,8 @@ def json(ctx: click.Context, tags: Optional[List[str]], output: Optional[pathlib """ANTA command to check network state with JSON result""" print_settings(ctx) results = ResultManager() - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) + with anta_progress_bar() as progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress)) print_json(results=results, output=output) ignore_status = ctx.obj["ignore_status"] @@ -71,7 +73,8 @@ def text(ctx: click.Context, tags: Optional[List[str]], search: Optional[str], s """ANTA command to check network states with text result""" print_settings(ctx) results = ResultManager() - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) + with anta_progress_bar() as progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress)) print_text(results=results, search=search, skip_error=skip_error) ignore_status = ctx.obj["ignore_status"] @@ -102,7 +105,8 @@ def tpl_report(ctx: click.Context, tags: Optional[List[str]], template: pathlib. """ANTA command to check network state with templated report""" print_settings(ctx, template, output) results = ResultManager() - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) + with anta_progress_bar() as progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress)) print_jinja(results=results, template=template, output=output) ignore_status = ctx.obj["ignore_status"] diff --git a/anta/cli/nrfu/utils.py b/anta/cli/nrfu/utils.py index de821e19e..c8f047539 100644 --- a/anta/cli/nrfu/utils.py +++ b/anta/cli/nrfu/utils.py @@ -15,6 +15,11 @@ import rich from rich.panel import Panel from rich.pretty import pprint +from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn + +# Adding our own ANTA spinner +# pylint: disable-next=W0611 +from rich.spinner import SPINNERS # type: ignore[attr-defined] from anta.cli.console import console from anta.reporter import ReportJinja, ReportTable @@ -80,3 +85,44 @@ def print_jinja(results: ResultManager, template: pathlib.Path, output: Optional if output is not None: with open(output, "w", encoding="utf-8") as file: file.write(report) + + +def anta_progress_bar() -> Progress: + """ + Return a customized Progress for progress bar + """ + return Progress( + SpinnerColumn("anta"), + TextColumn("•"), + TextColumn("{task.description}[progress.percentage]{task.percentage:>3.0f}%"), + BarColumn(bar_width=None), + MofNCompleteColumn(), + TextColumn("•"), + TimeElapsedColumn(), + TextColumn("•"), + TimeRemainingColumn(), + expand=True, + ) + + +# Overriding rich SPINNERS for our own +# so ignore warning for redefinition +SPINNERS = { # noqa: F811 + "anta": { + "interval": 150, + "frames": [ + "( 🐜)", + "( 🐜 )", + "( 🐜 )", + "( 🐜 )", + "( 🐜 )", + "(🐜 )", + "(🐌 )", + "( 🐌 )", + "( 🐌 )", + "( 🐌 )", + "( 🐌 )", + "( 🐌)", + ], + } +} diff --git a/anta/models.py b/anta/models.py index e57821c1a..da24d5158 100644 --- a/anta/models.py +++ b/anta/models.py @@ -8,9 +8,10 @@ from abc import ABC, abstractmethod from copy import deepcopy from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Coroutine, Dict, List, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Coroutine, Dict, List, Literal, Optional, TypeVar, Union, cast from pydantic import BaseModel, ConfigDict, conint +from rich.progress import Progress, TaskID from anta import __DEBUG__ from anta.result_manager.models import TestResult @@ -162,13 +163,17 @@ def __init__( # TODO document very well the order of eos_data eos_data: list[dict[Any, Any] | str] | None = None, labels: list[str] | None = None, + progress: Optional[Progress] = None, ): """Class constructor""" + # Accept 6 input arguments + # pylint: disable=R0913 self.logger: logging.Logger = logging.getLogger(f"{self.__module__}.{self.__class__.__name__}") self.device: AntaDevice = device self.result: TestResult = TestResult(name=device.name, test=self.name, test_category=self.categories, test_description=self.description) self.labels: List[str] = labels or [] self.instance_commands: List[AntaCommand] = [] + self.progress = progress # TODO - check optimization for deepcopy # Generating instance_commands from list of commands and template @@ -204,11 +209,7 @@ def all_data_collected(self) -> bool: def get_failed_commands(self) -> List[AntaCommand]: """returns a list of all the commands that have a populated failed field""" - errors = [] - for command in self.instance_commands: - if command.failed is not None: - errors.append(command) - return errors + return [command for command in self.instance_commands if command.failed is not None] def __init_subclass__(cls) -> None: """ @@ -233,7 +234,7 @@ async def collect(self) -> None: if __DEBUG__: self.logger.exception(message) else: - self.logger.error(message + f": {exc_to_str(e)}") + self.logger.error(f"{message}: {exc_to_str(e)}") self.result.is_error(exc_to_str(e)) @staticmethod @@ -286,8 +287,13 @@ async def wrapper( if __DEBUG__: self.logger.exception(message) else: - self.logger.error(message + f": {exc_to_str(e)}") + self.logger.error(f"{message}: {exc_to_str(e)}") self.result.is_error(exc_to_str(e)) + if self.progress: + # TODO this is hacky because we only have one task.. + # Should be id 0 - casting for mypy + nrfu_task: TaskID = cast(TaskID, 0) + self.progress.update(nrfu_task, advance=1) return self.result return wrapper diff --git a/anta/runner.py b/anta/runner.py index 488fb201c..792a50daa 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -7,6 +7,8 @@ import logging from typing import Any, Callable, Dict, List, Optional, Tuple +from rich.progress import Progress + from anta import __DEBUG__ from anta.inventory import AntaInventory from anta.result_manager import ResultManager @@ -25,6 +27,7 @@ async def main( tests: List[Tuple[Callable[..., TestResult], Dict[Any, Any]]], tags: Optional[List[str]] = None, established_only: bool = True, + progress: Optional[Progress] = None, ) -> None: """ Main coroutine to run ANTA. @@ -45,6 +48,8 @@ async def main( Returns: any: List of results. """ + # Accept 6 arguments here + # pylint: disable=R0913 await inventory.connect_inventory() @@ -57,14 +62,17 @@ async def main( template_params = test[1].get(TEST_TPL_PARAMS) try: # Instantiate AntaTest object - test_instance = test[0](device=device, template_params=template_params) + test_instance = test[0](device=device, template_params=template_params, progress=progress) coros.append(test_instance.test(eos_data=None, **test_params)) except Exception as e: # pylint: disable=broad-exception-caught message = "Error when creating ANTA tests" if __DEBUG__: logger.exception(message) else: - logger.error(message + f": {exc_to_str(e)}") + logger.error(f"{message}: {exc_to_str(e)}") + + if progress is not None: + progress.add_task("Running NRFU Tests...", total=len(coros)) logger.info("Running ANTA tests...") res = await asyncio.gather(*coros, return_exceptions=True) @@ -74,6 +82,6 @@ async def main( if __DEBUG__: logger.exception(message, exc_info=r) else: - logger.error(message + f": {exc_to_str(r)}") + logger.error(f"{message}: {exc_to_str(r)}") res.remove(r) manager.add_test_results(res) diff --git a/pylintrc b/pylintrc index ddbe1b88c..a02d42c67 100644 --- a/pylintrc +++ b/pylintrc @@ -2,7 +2,6 @@ disable= invalid-name, logging-fstring-interpolation, - logging-not-lazy, fixme [BASIC] From 56fdc6c5cbc4e095055b6d9393d08732d51056e5 Mon Sep 17 00:00:00 2001 From: gmuloc Date: Thu, 6 Jul 2023 18:01:43 +0200 Subject: [PATCH 2/8] WIP --- anta/cli/nrfu/utils.py | 44 +++++++++++++++++++----------------------- anta/decorators.py | 5 +++++ anta/models.py | 18 ++++++++++++----- 3 files changed, 38 insertions(+), 29 deletions(-) diff --git a/anta/cli/nrfu/utils.py b/anta/cli/nrfu/utils.py index c8f047539..9143ca4af 100644 --- a/anta/cli/nrfu/utils.py +++ b/anta/cli/nrfu/utils.py @@ -17,10 +17,6 @@ from rich.pretty import pprint from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn -# Adding our own ANTA spinner -# pylint: disable-next=W0611 -from rich.spinner import SPINNERS # type: ignore[attr-defined] - from anta.cli.console import console from anta.reporter import ReportJinja, ReportTable from anta.result_manager import ResultManager @@ -87,27 +83,9 @@ def print_jinja(results: ResultManager, template: pathlib.Path, output: Optional file.write(report) -def anta_progress_bar() -> Progress: - """ - Return a customized Progress for progress bar - """ - return Progress( - SpinnerColumn("anta"), - TextColumn("•"), - TextColumn("{task.description}[progress.percentage]{task.percentage:>3.0f}%"), - BarColumn(bar_width=None), - MofNCompleteColumn(), - TextColumn("•"), - TimeElapsedColumn(), - TextColumn("•"), - TimeRemainingColumn(), - expand=True, - ) - - -# Overriding rich SPINNERS for our own +# Adding our own ANTA spinner - overriding rich SPINNERS for our own # so ignore warning for redefinition -SPINNERS = { # noqa: F811 +rich.spinner.SPINNERS = { # type: ignore[attr-defined] # noqa: F811 "anta": { "interval": 150, "frames": [ @@ -126,3 +104,21 @@ def anta_progress_bar() -> Progress: ], } } + + +def anta_progress_bar() -> Progress: + """ + Return a customized Progress for progress bar + """ + return Progress( + SpinnerColumn("anta"), + TextColumn("•"), + TextColumn("{task.description}[progress.percentage]{task.percentage:>3.0f}%"), + BarColumn(bar_width=None), + MofNCompleteColumn(), + TextColumn("•"), + TimeElapsedColumn(), + TextColumn("•"), + TimeRemainingColumn(), + expand=True, + ) diff --git a/anta/decorators.py b/anta/decorators.py index da0a144a5..bd9d118d9 100644 --- a/anta/decorators.py +++ b/anta/decorators.py @@ -35,10 +35,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: anta_test = args[0] if anta_test.result.result != "unset": + anta_test.update_progres() return anta_test.result if anta_test.device.hw_model in platforms: anta_test.result.is_skipped(f"{anta_test.__class__.__name__} test is not supported on {anta_test.device.hw_model}.") + anta_test.update_progres() return anta_test.result return await function(*args, **kwargs) @@ -71,6 +73,7 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: anta_test = args[0] if anta_test.result.result != "unset": + anta_test.update_progres() return anta_test.result if family == "ipv4": @@ -92,10 +95,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: return anta_test.result if "vrfs" not in command.json_output: anta_test.result.is_skipped(f"no BGP configuration for {family} on this device") + anta_test.update_progres() return anta_test.result if len(bgp_vrfs := command.json_output["vrfs"]) == 0 or len(bgp_vrfs["default"]["peers"]) == 0: # No VRF anta_test.result.is_skipped(f"no {family} peer on this device") + anta_test.update_progres() return anta_test.result return await function(*args, **kwargs) diff --git a/anta/models.py b/anta/models.py index da24d5158..308f02230 100644 --- a/anta/models.py +++ b/anta/models.py @@ -289,15 +289,23 @@ async def wrapper( else: self.logger.error(f"{message}: {exc_to_str(e)}") self.result.is_error(exc_to_str(e)) - if self.progress: - # TODO this is hacky because we only have one task.. - # Should be id 0 - casting for mypy - nrfu_task: TaskID = cast(TaskID, 0) - self.progress.update(nrfu_task, advance=1) + + self.update_progress() return self.result return wrapper + def update_progress(self) -> None: + """ + Update progress bar if it exists + """ + if self.progress: + # TODO this is hacky because we only have one task.. + # Should be id 0 - casting for mypy + nrfu_task: TaskID = cast(TaskID, 0) + self.progress.update(nrfu_task, advance=1) + + @abstractmethod def test(self) -> Coroutine[Any, Any, TestResult]: """ From ca65b2256b2018bf1ed57d69a4425f8939e3db26 Mon Sep 17 00:00:00 2001 From: gmuloc Date: Thu, 6 Jul 2023 18:10:54 +0200 Subject: [PATCH 3/8] CI: They be linting --- anta/cli/exec/utils.py | 8 ++++---- anta/inventory/__init__.py | 2 +- anta/models.py | 1 - 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/anta/cli/exec/utils.py b/anta/cli/exec/utils.py index 6d0afe29b..edf352459 100644 --- a/anta/cli/exec/utils.py +++ b/anta/cli/exec/utils.py @@ -66,10 +66,10 @@ async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "tex logger.error(f"Could not collect commands on device {dev.name}: {exc_to_str(c.failed)}") return if c.ofmt == "json": - outfile = outdir / (command + ".json") + outfile = outdir / f"{command}.json" content = json.dumps(c.json_output, indent=2) elif c.ofmt == "text": - outfile = outdir / (command + ".log") + outfile = outdir / f"{command}.log" content = c.text_output with outfile.open(mode="w", encoding="UTF-8") as f: f.write(content) @@ -88,7 +88,7 @@ async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "tex if __DEBUG__: logger.exception(message, exc_info=r) else: - logger.error(message + f": {exc_to_str(r)}") + logger.error(f"{message}: {exc_to_str(r)}") async def collect_scheduled_show_tech(inv: AntaInventory, root_dir: Path, configure: bool, tags: Optional[List[str]] = None, latest: Optional[int] = None) -> None: @@ -150,7 +150,7 @@ async def collect(device: AntaDevice) -> None: if __DEBUG__: logger.exception(message) else: - logger.error(message + f": {exc_to_str(e)}") + logger.error(f"{message}: {exc_to_str(e)}") logger.info("Connecting to devices...") await inv.connect_inventory() diff --git a/anta/inventory/__init__.py b/anta/inventory/__init__.py index dd414dcdb..617651bb6 100644 --- a/anta/inventory/__init__.py +++ b/anta/inventory/__init__.py @@ -170,4 +170,4 @@ async def connect_inventory(self) -> None: if __DEBUG__: logger.exception(message, exc_info=r) else: - logger.error(message + f": {exc_to_str(r)}") + logger.error(f"{message}: {exc_to_str(r)}") diff --git a/anta/models.py b/anta/models.py index 308f02230..14b34513f 100644 --- a/anta/models.py +++ b/anta/models.py @@ -305,7 +305,6 @@ def update_progress(self) -> None: nrfu_task: TaskID = cast(TaskID, 0) self.progress.update(nrfu_task, advance=1) - @abstractmethod def test(self) -> Coroutine[Any, Any, TestResult]: """ From 063dc99822deb6c05d3c98693a7de399b4db3bf1 Mon Sep 17 00:00:00 2001 From: gmuloc Date: Fri, 7 Jul 2023 09:11:37 +0200 Subject: [PATCH 4/8] Fix: ProgresS with two 's' --- anta/decorators.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/anta/decorators.py b/anta/decorators.py index bd9d118d9..a75514c83 100644 --- a/anta/decorators.py +++ b/anta/decorators.py @@ -35,12 +35,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: anta_test = args[0] if anta_test.result.result != "unset": - anta_test.update_progres() + anta_test.update_progress() return anta_test.result if anta_test.device.hw_model in platforms: anta_test.result.is_skipped(f"{anta_test.__class__.__name__} test is not supported on {anta_test.device.hw_model}.") - anta_test.update_progres() + anta_test.update_progress() return anta_test.result return await function(*args, **kwargs) @@ -73,7 +73,7 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: anta_test = args[0] if anta_test.result.result != "unset": - anta_test.update_progres() + anta_test.update_progress() return anta_test.result if family == "ipv4": @@ -95,12 +95,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: return anta_test.result if "vrfs" not in command.json_output: anta_test.result.is_skipped(f"no BGP configuration for {family} on this device") - anta_test.update_progres() + anta_test.update_progress() return anta_test.result if len(bgp_vrfs := command.json_output["vrfs"]) == 0 or len(bgp_vrfs["default"]["peers"]) == 0: # No VRF anta_test.result.is_skipped(f"no {family} peer on this device") - anta_test.update_progres() + anta_test.update_progress() return anta_test.result return await function(*args, **kwargs) From b6e02fa8c43759d1eb8405c45d827f7edc142018 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthieu=20T=C3=A2che?= Date: Fri, 7 Jul 2023 10:03:12 +0200 Subject: [PATCH 5/8] make progress an AntaTest class attribute --- anta/cli/nrfu/commands.py | 17 +++++++++-------- anta/decorators.py | 14 ++++++++------ anta/models.py | 22 ++++++++++------------ anta/runner.py | 10 +++++----- 4 files changed, 32 insertions(+), 31 deletions(-) diff --git a/anta/cli/nrfu/commands.py b/anta/cli/nrfu/commands.py index 00aeb4b29..09d6fc1e1 100644 --- a/anta/cli/nrfu/commands.py +++ b/anta/cli/nrfu/commands.py @@ -14,6 +14,7 @@ from anta.cli.utils import parse_tags, return_code from anta.result_manager import ResultManager +from anta.models import AntaTest from anta.runner import main from .utils import anta_progress_bar, print_jinja, print_json, print_settings, print_table, print_text @@ -30,8 +31,8 @@ def table(ctx: click.Context, tags: Optional[List[str]], device: Optional[str], """ANTA command to check network states with table result""" print_settings(ctx) results = ResultManager() - with anta_progress_bar() as progress: - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress)) + with anta_progress_bar() as AntaTest.progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) print_table(results=results, device=device, test=test) # TODO make a util method to avoid repeating the same three line @@ -55,8 +56,8 @@ def json(ctx: click.Context, tags: Optional[List[str]], output: Optional[pathlib """ANTA command to check network state with JSON result""" print_settings(ctx) results = ResultManager() - with anta_progress_bar() as progress: - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress)) + with anta_progress_bar() as AntaTest.progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) print_json(results=results, output=output) ignore_status = ctx.obj["ignore_status"] @@ -73,8 +74,8 @@ def text(ctx: click.Context, tags: Optional[List[str]], search: Optional[str], s """ANTA command to check network states with text result""" print_settings(ctx) results = ResultManager() - with anta_progress_bar() as progress: - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress)) + with anta_progress_bar() as AntaTest.progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) print_text(results=results, search=search, skip_error=skip_error) ignore_status = ctx.obj["ignore_status"] @@ -105,8 +106,8 @@ def tpl_report(ctx: click.Context, tags: Optional[List[str]], template: pathlib. """ANTA command to check network state with templated report""" print_settings(ctx, template, output) results = ResultManager() - with anta_progress_bar() as progress: - asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress)) + with anta_progress_bar() as AntaTest.progress: + asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags)) print_jinja(results=results, template=template, output=output) ignore_status = ctx.obj["ignore_status"] diff --git a/anta/decorators.py b/anta/decorators.py index a75514c83..3b8defd3a 100644 --- a/anta/decorators.py +++ b/anta/decorators.py @@ -3,14 +3,16 @@ """ from functools import wraps from typing import Any, Callable, Dict, List, TypeVar, cast +import logging -from anta.models import AntaCommand +from anta.models import AntaCommand, AntaTest from anta.result_manager.models import TestResult from anta.tools.misc import exc_to_str # TODO - should probably use mypy Awaitable in some places rather than this everywhere - @gmuloc F = TypeVar("F", bound=Callable[..., Any]) +logger = logging.getLogger(__name__) def skip_on_platforms(platforms: List[str]) -> Callable[[F], F]: """ @@ -35,12 +37,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: anta_test = args[0] if anta_test.result.result != "unset": - anta_test.update_progress() + AntaTest.update_progress() return anta_test.result if anta_test.device.hw_model in platforms: anta_test.result.is_skipped(f"{anta_test.__class__.__name__} test is not supported on {anta_test.device.hw_model}.") - anta_test.update_progress() + AntaTest.update_progress() return anta_test.result return await function(*args, **kwargs) @@ -73,7 +75,7 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: anta_test = args[0] if anta_test.result.result != "unset": - anta_test.update_progress() + AntaTest.update_progress() return anta_test.result if family == "ipv4": @@ -95,12 +97,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult: return anta_test.result if "vrfs" not in command.json_output: anta_test.result.is_skipped(f"no BGP configuration for {family} on this device") - anta_test.update_progress() + AntaTest.update_progress() return anta_test.result if len(bgp_vrfs := command.json_output["vrfs"]) == 0 or len(bgp_vrfs["default"]["peers"]) == 0: # No VRF anta_test.result.is_skipped(f"no {family} peer on this device") - anta_test.update_progress() + AntaTest.update_progress() return anta_test.result return await function(*args, **kwargs) diff --git a/anta/models.py b/anta/models.py index 14b34513f..d0e22ccf7 100644 --- a/anta/models.py +++ b/anta/models.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Coroutine, Dict, List, Literal, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Coroutine, Dict, List, Literal, Optional, TypeVar, Union from pydantic import BaseModel, ConfigDict, conint from rich.progress import Progress, TaskID @@ -152,6 +152,8 @@ class AntaTest(ABC): commands: ClassVar[list[AntaCommand]] # TODO - today we support only one template per Test template: ClassVar[AntaTemplate] + progress: Optional[Progress] = None + nrfu_task: Optional[TaskID] = None # Optional class attributes test_filters: ClassVar[list[AntaTestFilter]] @@ -162,8 +164,7 @@ def __init__( template_params: list[dict[str, Any]] | None = None, # TODO document very well the order of eos_data eos_data: list[dict[Any, Any] | str] | None = None, - labels: list[str] | None = None, - progress: Optional[Progress] = None, + labels: list[str] | None = None ): """Class constructor""" # Accept 6 input arguments @@ -173,7 +174,6 @@ def __init__( self.result: TestResult = TestResult(name=device.name, test=self.name, test_category=self.categories, test_description=self.description) self.labels: List[str] = labels or [] self.instance_commands: List[AntaCommand] = [] - self.progress = progress # TODO - check optimization for deepcopy # Generating instance_commands from list of commands and template @@ -290,20 +290,18 @@ async def wrapper( self.logger.error(f"{message}: {exc_to_str(e)}") self.result.is_error(exc_to_str(e)) - self.update_progress() + AntaTest.update_progress() return self.result return wrapper - def update_progress(self) -> None: + @classmethod + def update_progress(cls) -> None: """ - Update progress bar if it exists + Update progress bar for all AntaTest objects if it exists """ - if self.progress: - # TODO this is hacky because we only have one task.. - # Should be id 0 - casting for mypy - nrfu_task: TaskID = cast(TaskID, 0) - self.progress.update(nrfu_task, advance=1) + if cls.progress and cls.nrfu_task is not None: + cls.progress.update(cls.nrfu_task, advance=1) @abstractmethod def test(self) -> Coroutine[Any, Any, TestResult]: diff --git a/anta/runner.py b/anta/runner.py index 792a50daa..bbf77449a 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -9,6 +9,7 @@ from rich.progress import Progress +from anta.models import AntaTest from anta import __DEBUG__ from anta.inventory import AntaInventory from anta.result_manager import ResultManager @@ -26,8 +27,7 @@ async def main( inventory: AntaInventory, tests: List[Tuple[Callable[..., TestResult], Dict[Any, Any]]], tags: Optional[List[str]] = None, - established_only: bool = True, - progress: Optional[Progress] = None, + established_only: bool = True ) -> None: """ Main coroutine to run ANTA. @@ -62,7 +62,7 @@ async def main( template_params = test[1].get(TEST_TPL_PARAMS) try: # Instantiate AntaTest object - test_instance = test[0](device=device, template_params=template_params, progress=progress) + test_instance = test[0](device=device, template_params=template_params) coros.append(test_instance.test(eos_data=None, **test_params)) except Exception as e: # pylint: disable=broad-exception-caught message = "Error when creating ANTA tests" @@ -71,8 +71,8 @@ async def main( else: logger.error(f"{message}: {exc_to_str(e)}") - if progress is not None: - progress.add_task("Running NRFU Tests...", total=len(coros)) + if AntaTest.progress is not None: + AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coros)) logger.info("Running ANTA tests...") res = await asyncio.gather(*coros, return_exceptions=True) From 1e8a8924173dab1c0016d57c088a852f6f431603 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthieu=20T=C3=A2che?= Date: Fri, 7 Jul 2023 10:05:18 +0200 Subject: [PATCH 6/8] lint --- anta/decorators.py | 2 -- anta/runner.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/anta/decorators.py b/anta/decorators.py index 3b8defd3a..a22313f33 100644 --- a/anta/decorators.py +++ b/anta/decorators.py @@ -3,7 +3,6 @@ """ from functools import wraps from typing import Any, Callable, Dict, List, TypeVar, cast -import logging from anta.models import AntaCommand, AntaTest from anta.result_manager.models import TestResult @@ -12,7 +11,6 @@ # TODO - should probably use mypy Awaitable in some places rather than this everywhere - @gmuloc F = TypeVar("F", bound=Callable[..., Any]) -logger = logging.getLogger(__name__) def skip_on_platforms(platforms: List[str]) -> Callable[[F], F]: """ diff --git a/anta/runner.py b/anta/runner.py index bbf77449a..4d8657a4f 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -7,8 +7,6 @@ import logging from typing import Any, Callable, Dict, List, Optional, Tuple -from rich.progress import Progress - from anta.models import AntaTest from anta import __DEBUG__ from anta.inventory import AntaInventory From 48ce1b49b0c513293f18deeaa8791152c7f43b7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthieu=20T=C3=A2che?= Date: Fri, 7 Jul 2023 10:08:23 +0200 Subject: [PATCH 7/8] black --- anta/cli/nrfu/commands.py | 2 +- anta/models.py | 2 +- anta/runner.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/anta/cli/nrfu/commands.py b/anta/cli/nrfu/commands.py index 09d6fc1e1..d53be8f17 100644 --- a/anta/cli/nrfu/commands.py +++ b/anta/cli/nrfu/commands.py @@ -13,8 +13,8 @@ import click from anta.cli.utils import parse_tags, return_code -from anta.result_manager import ResultManager from anta.models import AntaTest +from anta.result_manager import ResultManager from anta.runner import main from .utils import anta_progress_bar, print_jinja, print_json, print_settings, print_table, print_text diff --git a/anta/models.py b/anta/models.py index d0e22ccf7..8e3d36b8f 100644 --- a/anta/models.py +++ b/anta/models.py @@ -164,7 +164,7 @@ def __init__( template_params: list[dict[str, Any]] | None = None, # TODO document very well the order of eos_data eos_data: list[dict[Any, Any] | str] | None = None, - labels: list[str] | None = None + labels: list[str] | None = None, ): """Class constructor""" # Accept 6 input arguments diff --git a/anta/runner.py b/anta/runner.py index 4d8657a4f..18be2d5d2 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -7,9 +7,9 @@ import logging from typing import Any, Callable, Dict, List, Optional, Tuple -from anta.models import AntaTest from anta import __DEBUG__ from anta.inventory import AntaInventory +from anta.models import AntaTest from anta.result_manager import ResultManager from anta.result_manager.models import TestResult from anta.tools.misc import exc_to_str @@ -25,7 +25,7 @@ async def main( inventory: AntaInventory, tests: List[Tuple[Callable[..., TestResult], Dict[Any, Any]]], tags: Optional[List[str]] = None, - established_only: bool = True + established_only: bool = True, ) -> None: """ Main coroutine to run ANTA. From 3f651f5899b0e95604c022d05b90e1b2dbf3ecde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthieu=20T=C3=A2che?= Date: Fri, 7 Jul 2023 10:08:49 +0200 Subject: [PATCH 8/8] fix logical expression --- anta/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/anta/models.py b/anta/models.py index 8e3d36b8f..7e5b22f0e 100644 --- a/anta/models.py +++ b/anta/models.py @@ -300,7 +300,7 @@ def update_progress(cls) -> None: """ Update progress bar for all AntaTest objects if it exists """ - if cls.progress and cls.nrfu_task is not None: + if cls.progress and (cls.nrfu_task is not None): cls.progress.update(cls.nrfu_task, advance=1) @abstractmethod