diff --git a/anta/device.py b/anta/device.py index 087f3b57b..4ebd4abbb 100644 --- a/anta/device.py +++ b/anta/device.py @@ -16,12 +16,12 @@ from aiocache import Cache from aiocache.plugins import HitMissRatioPlugin from asyncssh import SSHClientConnection, SSHClientConnectionOptions -from httpx import ConnectError, HTTPError, TimeoutException +from httpx import ConnectError, HTTPError, Limits, TimeoutException -import asynceapi from anta import __DEBUG__ from anta.logger import anta_log_exception, exc_to_str from anta.models import AntaCommand +from asynceapi import Device, EapiCommandError if TYPE_CHECKING: from collections.abc import Iterator @@ -117,7 +117,7 @@ def __rich_repr__(self) -> Iterator[tuple[str, Any]]: yield "disable_cache", self.cache is None @abstractmethod - async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: + async def _collect(self, anta_commands: list[AntaCommand], *, req_format: Literal["json", "text"] = "json", req_id: str | None = None) -> None: """Collect device command output. This abstract coroutine can be used to implement any command collection method @@ -136,38 +136,7 @@ async def _collect(self, command: AntaCommand, *, collection_id: str | None = No collection_id: An identifier used to build the eAPI request ID. """ - async def collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: - """Collect the output for a specified command. - - When caching is activated on both the device and the command, - this method prioritizes retrieving the output from the cache. In cases where the output isn't cached yet, - it will be freshly collected and then stored in the cache for future access. - The method employs asynchronous locks based on the command's UID to guarantee exclusive access to the cache. - - When caching is NOT enabled, either at the device or command level, the method directly collects the output - via the private `_collect` method without interacting with the cache. - - Parameters - ---------- - command: The command to collect. - collection_id: An identifier used to build the eAPI request ID. - """ - # Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough - # https://github.com/pylint-dev/pylint/issues/7258 - if self.cache is not None and self.cache_locks is not None and command.use_cache: - async with self.cache_locks[command.uid]: - cached_output = await self.cache.get(command.uid) # pylint: disable=no-member - - if cached_output is not None: - logger.debug("Cache hit for %s on %s", command.command, self.name) - command.output = cached_output - else: - await self._collect(command=command, collection_id=collection_id) - await self.cache.set(command.uid, command.output) # pylint: disable=no-member - else: - await self._collect(command=command, collection_id=collection_id) - - async def collect_commands(self, commands: list[AntaCommand], *, collection_id: str | None = None) -> None: + async def collect_commands(self, anta_commands: list[AntaCommand], *, req_format: Literal["text", "json"] = "json", req_id: str) -> None: """Collect multiple commands. Parameters @@ -175,7 +144,33 @@ async def collect_commands(self, commands: list[AntaCommand], *, collection_id: commands: The commands to collect. collection_id: An identifier used to build the eAPI request ID. """ - await asyncio.gather(*(self.collect(command=command, collection_id=collection_id) for command in commands)) + # FIXME: Avoid querying the cache for the initial commands that are not cached. + commands_to_collect = [] + + # FIXME: Don't loop over commands if the cache is disabled + for command in anta_commands: + if self.cache is not None and self.cache_locks is not None and command.use_cache: + async with self.cache_locks[command.uid]: + # Need to disable pylint no-member as Cache is a proxy class and pylint is not smart enough + # https://github.com/pylint-dev/pylint/issues/7258 + cached_output = await self.cache.get(command.uid) # pylint: disable=no-member + + if cached_output is not None: + logger.debug("Cache hit for %s on %s", command.command, self.name) + command.output = cached_output + else: + commands_to_collect.append(command) + else: + commands_to_collect.append(command) + + # Collect the batch of commands that are not cached + if commands_to_collect: + await self._collect(commands_to_collect, req_format=req_format, req_id=req_id) + # Cache the outputs of the collected commands + for command in commands_to_collect: + if self.cache is not None and self.cache_locks is not None and command.use_cache: + async with self.cache_locks[command.uid]: + await self.cache.set(command.uid, command.output) # pylint: disable=no-member @abstractmethod async def refresh(self) -> None: @@ -271,7 +266,8 @@ def __init__( raise ValueError(message) self.enable = enable self._enable_password = enable_password - self._session: asynceapi.Device = asynceapi.Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout) + # TODO: Move the max_connections setting change to a separate PR + self._session: Device = Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout, limits=Limits(max_connections=7)) ssh_params: dict[str, Any] = {} if insecure: ssh_params["known_hosts"] = None @@ -306,7 +302,79 @@ def _keys(self) -> tuple[Any, ...]: """ return (self._session.host, self._session.port) - async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: # noqa: C901 function is too complex - because of many required except blocks #pylint: disable=line-too-long + async def _handle_eapi_command_error(self, exception: EapiCommandError, anta_commands: list[AntaCommand], *, req_format: str, req_id: str) -> None: + """Handle EapiCommandError exceptions.""" + # Populate the output attribute of the AntaCommand objects with the commands that passed + passed_outputs = exception.passed[1:] if self.enable else exception.passed + for anta_command, output in zip(anta_commands, passed_outputs): + anta_command.output = output + + # Populate the errors attribute of the AntaCommand object of the command that failed + err_at = exception.err_at - 1 if self.enable else exception.err_at + anta_command = anta_commands[err_at] + anta_command.errors = exception.errors + if anta_command.requires_privileges: + logger.error( + "Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.", + anta_command.command, + self.name, + ) + + if anta_command.supported: + error_message = exception.errors[0] if len(exception.errors) == 1 else exception.errors + logger.error( + "Command '%s' failed on %s: %s", + anta_command.command, + self.name, + error_message, + ) + else: + logger.error("Command '%s' is not supported on %s (%s).", anta_command.command, self.name, self.hw_model) + + # Collect the commands that were not executed + await self._collect(anta_commands=anta_commands[err_at + 1 :], req_format=req_format, req_id=req_id) + + def _handle_timeout_exception(self, exception: TimeoutException, anta_commands: list[AntaCommand]) -> None: + """Handle TimeoutException exceptions.""" + # FIXME: Handle timeouts more gracefully + for anta_command in anta_commands: + anta_command.errors = [exc_to_str(exception)] + + timeouts = self._session.timeout.as_dict() + logger.error( + "%s occurred while sending commands to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s", + exc_to_str(exception), + self.name, + timeouts["connect"], + timeouts["read"], + timeouts["write"], + timeouts["pool"], + ) + + def _handle_connect_os_error(self, exception: ConnectError | OSError, anta_commands: list[AntaCommand]) -> None: + """Handle HTTPX ConnectError and OSError exceptions.""" + # FIXME: Handle connection errors more gracefully + for anta_command in anta_commands: + anta_command.errors = [exc_to_str(exception)] + + if (isinstance(exc := exception.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance( + os_error := exception, OSError + ): + if isinstance(os_error.__cause__, OSError): + os_error = os_error.__cause__ + logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error) + else: + anta_log_exception(exception, f"An error occurred while issuing an eAPI request to {self.name}", logger) + + def _handle_http_error(self, exception: HTTPError, anta_commands: list[AntaCommand]) -> None: + """Handle HTTPError exceptions.""" + # FIXME: Handle HTTP errors more gracefully + for anta_command in anta_commands: + anta_command.errors = [exc_to_str(exception)] + + anta_log_exception(exception, f"An error occurred while issuing an eAPI request to {self.name}", logger) + + async def _collect(self, anta_commands: list[AntaCommand], *, req_format: Literal["json", "text"] = "json", req_id: str) -> None: """Collect device command output from EOS using aio-eapi. Supports outformat `json` and `text` as output structure. @@ -318,65 +386,43 @@ async def _collect(self, command: AntaCommand, *, collection_id: str | None = No command: The command to collect. collection_id: An identifier used to build the eAPI request ID. """ - commands: list[dict[str, str | int]] = [] + commands = [ + {"cmd": anta_command.command, "revision": anta_command.revision} if anta_command.revision else {"cmd": anta_command.command} + for anta_command in anta_commands + ] + if self.enable and self._enable_password is not None: - commands.append( - { - "cmd": "enable", - "input": str(self._enable_password), - }, - ) + commands.insert(0, {"cmd": "enable", "input": str(self._enable_password)}) elif self.enable: # No password - commands.append({"cmd": "enable"}) - commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}] + commands.insert(0, {"cmd": "enable"}) + try: - response: list[dict[str, Any] | str] = await self._session.cli( + response = await self._session.cli( commands=commands, - ofmt=command.ofmt, - version=command.version, - req_id=f"ANTA-{collection_id}-{id(command)}" if collection_id else f"ANTA-{id(command)}", - ) # type: ignore[assignment] # multiple commands returns a list - # Do not keep response of 'enable' command - command.output = response[-1] - except asynceapi.EapiCommandError as e: + ofmt=req_format, + req_id=f"ANTA-{req_id}", + ) + # If enable was used, exclude the first element from the response + if self.enable: + response = response[1:] + + # Populate the output attribute of the AntaCommand objects + for anta_command, command_output in zip(anta_commands, response): + anta_command.output = command_output + + except EapiCommandError as e: # This block catches exceptions related to EOS issuing an error. - command.errors = e.errors - if command.requires_privileges: - logger.error( - "Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.", command.command, self.name - ) - if command.supported: - logger.error("Command '%s' failed on %s: %s", command.command, self.name, e.errors[0] if len(e.errors) == 1 else e.errors) - else: - logger.debug("Command '%s' is not supported on '%s' (%s)", command.command, self.name, self.hw_model) + await self._handle_eapi_command_error(e, anta_commands, req_format=req_format, req_id=req_id) except TimeoutException as e: - # This block catches Timeout exceptions. - command.errors = [exc_to_str(e)] - timeouts = self._session.timeout.as_dict() - logger.error( - "%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s", - exc_to_str(e), - self.name, - timeouts["connect"], - timeouts["read"], - timeouts["write"], - timeouts["pool"], - ) - except (ConnectError, OSError) as e: - # This block catches OSError and socket issues related exceptions. - command.errors = [exc_to_str(e)] - if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(os_error := e, OSError): # pylint: disable=no-member - if isinstance(os_error.__cause__, OSError): - os_error = os_error.__cause__ - logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error) - else: - anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger) + # This block catches exceptions related to the timeout of the request. + self._handle_timeout_exception(e, anta_commands) + except ConnectError as e: + # This block catches exceptions related to the connection to the device. + self._handle_connect_os_error(e, anta_commands) except HTTPError as e: - # This block catches most of the httpx Exceptions and logs a general message. - command.errors = [exc_to_str(e)] - anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger) - logger.debug("%s: %s", self.name, command) + # This block catches exceptions related to the HTTP connection. + self._handle_http_error(e, anta_commands) async def refresh(self) -> None: """Update attributes of an AsyncEOSDevice instance. @@ -389,8 +435,8 @@ async def refresh(self) -> None: logger.debug("Refreshing device %s", self.name) self.is_online = await self._session.check_connection() if self.is_online: - show_version = AntaCommand(command="show version") - await self._collect(show_version) + show_version = AntaCommand(command="show version", revision=1) + await self._collect([show_version], req_format="json", req_id="Refresh") if not show_version.collected: logger.warning("Cannot get hardware information from device %s", self.name) else: diff --git a/anta/models.py b/anta/models.py index e2cf49857..01857057b 100644 --- a/anta/models.py +++ b/anta/models.py @@ -5,11 +5,13 @@ from __future__ import annotations +import asyncio import hashlib import logging import re from abc import ABC, abstractmethod -from functools import wraps +from collections import defaultdict +from functools import cached_property, wraps from string import Formatter from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar @@ -25,6 +27,7 @@ from rich.progress import Progress, TaskID + from anta.catalog import AntaTestDefinition from anta.device import AntaDevice F = TypeVar("F", bound=Callable[..., Any]) @@ -165,12 +168,11 @@ class AntaCommand(BaseModel): params: AntaParamsBaseModel = AntaParamsBaseModel() use_cache: bool = True - @property + @cached_property def uid(self) -> str: """Generate a unique identifier for this command.""" uid_str = f"{self.command}_{self.version}_{self.revision or 'NA'}_{self.ofmt}" - # Ignoring S324 probable use of insecure hash function - sha1 is enough for our needs. - return hashlib.sha1(uid_str.encode()).hexdigest() # noqa: S324 + return hashlib.sha256(uid_str.encode()).hexdigest() @property def json_output(self) -> dict[str, Any]: @@ -302,6 +304,8 @@ def test(self) -> None: instance_commands: List of AntaCommand instances of this test result: TestResult instance representing the result of this test logger: Python logger for this test instance + event: asyncio.Event used by the AntaTestManager to signal the test that all commands have been collected + so that it can start running the validation """ # Mandatory class attributes @@ -398,6 +402,7 @@ def __init__( categories=self.categories, description=self.description, ) + self.event: asyncio.Event | None = None self._init_inputs(inputs) if self.result.result == "unset": self._init_commands(eos_data) @@ -523,19 +528,6 @@ def blocked(self) -> bool: state = True return state - async def collect(self) -> None: - """Collect outputs of all commands of this test class from the device of this test instance.""" - try: - if self.blocked is False: - await self.device.collect_commands(self.instance_commands, collection_id=self.name) - except Exception as e: # pylint: disable=broad-exception-caught - # device._collect() is user-defined code. - # We need to catch everything if we want the AntaTest object - # to live until the reporting - message = f"Exception raised while collecting commands for test {self.name} (on device {self.device.name})" - anta_log_exception(e, message, self.logger) - self.result.is_error(message=exc_to_str(e)) - @staticmethod def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]: """Decorate the `test()` method in child classes. @@ -552,7 +544,6 @@ def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]: async def wrapper( self: AntaTest, eos_data: list[dict[Any, Any] | str] | None = None, - **kwargs: dict[str, Any], ) -> TestResult: """Inner function for the anta_test decorator. @@ -561,7 +552,6 @@ async def wrapper( self: The test instance. eos_data: Populate outputs of the test commands instead of collecting from devices. This list must have the same length and order than the `instance_commands` instance attribute. - kwargs: Any keyword argument to pass to the test. Returns ------- @@ -576,9 +566,12 @@ async def wrapper( self.save_commands_data(eos_data) self.logger.debug("Test %s initialized with input data %s", self.name, eos_data) - # If some data is missing, try to collect + # Wait until all commands have been collected by the manager before running the test + logger.debug("Waiting for all commands to be collected for %s on device %s", self.name, self.device.name) + await self.event.wait() + + self.logger.debug("Starting validation for %s on device %s", self.name, self.device.name) if not self.collected: - await self.collect() if self.result.result != "unset": AntaTest.update_progress() return self.result @@ -594,8 +587,9 @@ async def wrapper( AntaTest.update_progress() return self.result + # Run the test in a separate thread to avoid blocking the event loop try: - function(self, **kwargs) + await asyncio.to_thread(function, self) except Exception as e: # pylint: disable=broad-exception-caught # test() is user-defined code. # We need to catch everything if we want the AntaTest object @@ -606,6 +600,8 @@ async def wrapper( # TODO: find a correct way to time test execution AntaTest.update_progress() + + logger.debug("Validation completed for %s on device %s", self.name, self.device.name) return self.result return wrapper @@ -617,7 +613,7 @@ def update_progress(cls: type[AntaTest]) -> None: cls.progress.update(cls.nrfu_task, advance=1) @abstractmethod - def test(self) -> Coroutine[Any, Any, TestResult]: + def test(self) -> None: """Core of the test logic. This is an abstractmethod that must be implemented by child classes. @@ -636,3 +632,92 @@ def test(self) -> None: ``` """ + + +class AntaTestManager: + """TODO: Add docstring. + + # FIXME: Handle different batch sizes for different tests + # FIXME: Handle decorators that skip tests. For now commands are still sent to the device. + """ + + def __init__(self, device: AntaDevice, batch_size: int) -> None: + """TODO: Add docstring.""" + self.device = device + self.batch_size = batch_size + self.completed_command_ids: set[int] = set() + self.test_map: defaultdict[AntaTest, set[int]] = defaultdict(set) + self.events: dict[AntaTest, asyncio.Event] = {} + + async def run_tests(self, test_definitions: set[AntaTestDefinition]) -> None: + json_commands: list[AntaCommand] = [] + text_commands: list[AntaCommand] = [] + test_tasks = [] + batch_tasks = [] + + for test_definition in test_definitions: + try: + test_instance = test_definition.test(device=self.device, inputs=test_definition.inputs) + # Skip the test if it has blocked commands + if test_instance.blocked is True: + continue + + test_instance.event = asyncio.Event() + self.events[test_instance] = test_instance.event + + for command in test_instance.instance_commands: + self.test_map[test_instance].add(id(command)) + if command.ofmt == "json": + json_commands.append(command) + elif command.ofmt == "text": + text_commands.append(command) + + test_tasks.append(asyncio.create_task(test_instance.test())) + except Exception as exc: # pylint: disable=broad-exception-caught + # An AntaTest instance is potentially user-defined code. + # We need to catch everything and exit gracefully with an error message. + message = "\n".join( + [ + f"There is an error when creating test {test_definition.test.module}.{test_definition.test.__name__}.", + f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", + ], + ) + anta_log_exception(exc, message, logger) + + total_commands = len(json_commands) + len(text_commands) + logger.debug("Total commands to process for %s: %d", self.device.name, total_commands) + + # Create batches for JSON commands + json_batches = [json_commands[i : i + self.batch_size] for i in range(0, len(json_commands), self.batch_size)] + # Create batches for text commands + text_batches = [text_commands[i : i + self.batch_size] for i in range(0, len(text_commands), self.batch_size)] + + logger.debug("Number of JSON batches for %s: %d", self.device.name, len(json_batches)) + logger.debug("Number of text batches for %s: %d", self.device.name, len(text_batches)) + + # Process JSON batches + for i, batch in enumerate(json_batches): + batch_tasks.append(asyncio.create_task(self.process_batch(i, batch, "json"))) + + # Process text batches + for i, batch in enumerate(text_batches): + batch_tasks.append(asyncio.create_task(self.process_batch(i, batch, "text"))) + + # Make sure all batch tasks are completed + await asyncio.gather(*batch_tasks) + + # Make sure all test tasks are completed and return the results + return await asyncio.gather(*test_tasks) + + async def process_batch(self, batch_id: int, batch: list[AntaCommand], batch_format: Literal["text", "json"] = "json") -> None: + await self.device.collect_commands(batch, req_format=batch_format, req_id=f"Batch #{batch_id}") + for command in batch: + cmd_id = id(command) + self.completed_command_ids.add(cmd_id) + self.check_test_completion(cmd_id) + + def check_test_completion(self, cmd_id: int) -> None: + for test_instance, cmd_ids in self.test_map.items(): + if cmd_id in cmd_ids and cmd_ids.issubset(self.completed_command_ids): + logger.debug("All commands completed for %s", test_instance.name) + self.events[test_instance].set() diff --git a/anta/runner.py b/anta/runner.py index df4c70cc4..d0af7463c 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -10,11 +10,11 @@ import os import resource from collections import defaultdict -from typing import TYPE_CHECKING, Any +from itertools import chain +from typing import TYPE_CHECKING -from anta import GITHUB_SUGGESTION -from anta.logger import anta_log_exception, exc_to_str -from anta.models import AntaTest +from anta.logger import exc_to_str +from anta.models import AntaTest, AntaTestManager from anta.tools import Catchtime, cprofile if TYPE_CHECKING: @@ -24,7 +24,6 @@ from anta.device import AntaDevice from anta.inventory import AntaInventory from anta.result_manager import ResultManager - from anta.result_manager.models import TestResult logger = logging.getLogger(__name__) @@ -157,36 +156,6 @@ def prepare_tests( return device_to_tests -def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]]) -> list[Coroutine[Any, Any, TestResult]]: - """Get the coroutines for the ANTA run. - - Parameters - ---------- - selected_tests: A mapping of devices to the tests to run. The selected tests are generated by the `prepare_tests` function. - - Returns - ------- - The list of coroutines to run. - """ - coros = [] - for device, test_definitions in selected_tests.items(): - for test in test_definitions: - try: - test_instance = test.test(device=device, inputs=test.inputs) - coros.append(test_instance.test()) - except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught - # An AntaTest instance is potentially user-defined code. - # We need to catch everything and exit gracefully with an error message. - message = "\n".join( - [ - f"There is an error when creating test {test.test.module}.{test.test.__name__}.", - f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", - ], - ) - anta_log_exception(e, message, logger) - return coros - - @cprofile() async def main( # noqa: PLR0913 manager: ResultManager, @@ -234,6 +203,12 @@ async def main( # noqa: PLR0913 if selected_tests is None: return + coros: list[Coroutine] = [] + + for device, test_definitions in selected_tests.items(): + test_manager = AntaTestManager(device=device, batch_size=100) + coros.append(test_manager.run_tests(test_definitions)) + run_info = ( "--- ANTA NRFU Run Information ---\n" f"Number of devices: {len(inventory)} ({len(selected_inventory)} established)\n" @@ -251,20 +226,18 @@ async def main( # noqa: PLR0913 "Please consult the ANTA FAQ." ) - coroutines = get_coroutines(selected_tests) - if dry_run: logger.info("Dry-run mode, exiting before running the tests.") - for coro in coroutines: + for coro in coros: coro.close() return if AntaTest.progress is not None: - AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coroutines)) + AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=catalog.final_tests_count) with Catchtime(logger=logger, message="Running ANTA tests"): - test_results = await asyncio.gather(*coroutines) - for r in test_results: - manager.add(r) + results = chain.from_iterable(await asyncio.gather(*coros)) + for result in results: + manager.add(result) log_cache_statistics(selected_inventory.devices) diff --git a/asynceapi/device.py b/asynceapi/device.py index ca206d3e4..5c34b8e7c 100644 --- a/asynceapi/device.py +++ b/asynceapi/device.py @@ -9,6 +9,7 @@ from __future__ import annotations +from json import loads from socket import getservbyname from typing import TYPE_CHECKING, Any @@ -244,7 +245,9 @@ async def jsonrpc_exec(self, jsonrpc: dict[str, Any]) -> list[dict[str, Any] | s commands = jsonrpc["params"]["cmds"] ofmt = jsonrpc["params"]["format"] - get_output = (lambda _r: _r["output"]) if ofmt == "text" else (lambda _r: _r) + # Return the correct output format based on the requested ofmt. + def get_output(response: dict[str, Any]) -> str | dict[str, Any]: + return response["output"] if ofmt == "text" else loads(response) if isinstance(response, str) else response # if there are no errors then return the list of command results. if (err_data := body.get("error")) is None: @@ -272,9 +275,12 @@ async def jsonrpc_exec(self, jsonrpc: dict[str, Any]) -> list[dict[str, Any] | s err_at = len_data - 1 err_msg = err_data["message"] + # FIXME: EapiCommandError exception only supports complex commands (dict) and not simple commands (str) + # https://github.com/aristanetworks/anta/issues/718 raise EapiCommandError( passed=[get_output(cmd_data[cmd_i]) for cmd_i, cmd in enumerate(commands[:err_at])], failed=commands[err_at]["cmd"], + err_at=err_at, errors=cmd_data[err_at]["errors"], errmsg=err_msg, not_exec=commands[err_at + 1 :], diff --git a/asynceapi/errors.py b/asynceapi/errors.py index 020d3dc2f..314a08bc7 100644 --- a/asynceapi/errors.py +++ b/asynceapi/errors.py @@ -19,15 +19,17 @@ class EapiCommandError(RuntimeError): ---------- failed: the failed command errmsg: a description of the failure reason + err_at: the index of the command that failed errors: the command failure details passed: a list of command results of the commands that passed not_exec: a list of commands that were not executed """ - def __init__(self, failed: str, errors: list[str], errmsg: str, passed: list[str | dict[str, Any]], not_exec: list[dict[str, Any]]) -> None: # pylint: disable=too-many-arguments + def __init__(self, failed: str, err_at: int, errors: list[str], errmsg: str, passed: list[str | dict[str, Any]], not_exec: list[dict[str, Any]]) -> None: # pylint: disable=too-many-arguments """Initialize for the EapiCommandError exception.""" self.failed = failed self.errmsg = errmsg + self.err_at = err_at self.errors = errors self.passed = passed self.not_exec = not_exec