From 30382013f04488d1416f5fdf2fd9ee951c354653 Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Tue, 11 Jun 2024 11:43:40 -0400 Subject: [PATCH 01/12] WIP refactor _collect() --- anta/catalog.py | 13 +- anta/decorators.py | 4 +- anta/device.py | 337 ++++++++++++++++++++++++++++++++------------ anta/models.py | 27 ++-- anta/runner.py | 27 +--- asynceapi/device.py | 1 + asynceapi/errors.py | 4 +- 7 files changed, 281 insertions(+), 132 deletions(-) diff --git a/anta/catalog.py b/anta/catalog.py index 142640ecb..7782d9f83 100644 --- a/anta/catalog.py +++ b/anta/catalog.py @@ -10,6 +10,7 @@ import math from collections import defaultdict from inspect import isclass +from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union @@ -372,18 +373,20 @@ def from_list(data: ListAntaTestTuples) -> AntaCatalog: raise return AntaCatalog(tests) - def merge(self, catalog: AntaCatalog) -> AntaCatalog: - """Merge two AntaCatalog instances. + @staticmethod + def merge(catalogs: list[AntaCatalog]) -> AntaCatalog: + """Merge multiple AntaCatalog instances. Args: ---- - catalog: AntaCatalog instance to merge to this instance. + catalogs: List of AntaCatalog instances to merge. Returns ------- - A new AntaCatalog instance containing the tests of the two instances. + A new AntaCatalog instance containing the tests of all the instances. """ - return AntaCatalog(tests=self.tests + catalog.tests) + combined_tests = list(chain(*(catalog.tests for catalog in catalogs))) + return AntaCatalog(tests=combined_tests) def dump(self) -> AntaCatalogFile: """Return an AntaCatalogFile instance from this AntaCatalog instance. diff --git a/anta/decorators.py b/anta/decorators.py index dc57e13ec..fd2a7dfae 100644 --- a/anta/decorators.py +++ b/anta/decorators.py @@ -88,7 +88,7 @@ def decorator(function: F) -> F: """ @wraps(function) - async def wrapper(*args: Any, **kwargs: Any) -> TestResult: + def wrapper(*args: Any, **kwargs: Any) -> TestResult: """Check the device's hardware model and conditionally run or skip the test. This wrapper inspects the hardware model of the device the test is run on. @@ -105,7 +105,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> TestResult: AntaTest.update_progress() return anta_test.result - return await function(*args, **kwargs) + return function(*args, **kwargs) return cast(F, wrapper) diff --git a/anta/device.py b/anta/device.py index f0ec6a00c..5a048701d 100644 --- a/anta/device.py +++ b/anta/device.py @@ -5,21 +5,22 @@ from __future__ import annotations -import asyncio import logging from abc import ABC, abstractmethod +from asyncio import Lock from collections import defaultdict from typing import TYPE_CHECKING, Any, Literal +from uuid import uuid4 import asyncssh import httpcore from aiocache import Cache from aiocache.plugins import HitMissRatioPlugin +from asynceapi import Device, EapiCommandError from asyncssh import SSHClientConnection, SSHClientConnectionOptions from httpx import ConnectError, HTTPError, TimeoutException -import asynceapi -from anta import __DEBUG__ +from anta import __DEBUG__, GITHUB_SUGGESTION from anta.logger import anta_log_exception, exc_to_str from anta.models import AntaCommand @@ -27,6 +28,9 @@ from collections.abc import Iterator from pathlib import Path + from anta.catalog import AntaTestDefinition + from anta.result_manager.models import TestResult + logger = logging.getLogger(__name__) # Do not load the default keypairs multiple times due to a performance issue introduced in cryptography 37.0 @@ -70,7 +74,7 @@ def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bo self.is_online: bool = False self.established: bool = False self.cache: Cache | None = None - self.cache_locks: defaultdict[str, asyncio.Lock] | None = None + self.cache_locks: defaultdict[str, Lock] | None = None # Initialize cache if not disabled if not disable_cache: @@ -92,7 +96,7 @@ def __hash__(self) -> int: def _init_cache(self) -> None: """Initialize cache for the device, can be overridden by subclasses to manipulate how it works.""" self.cache = Cache(cache_class=Cache.MEMORY, ttl=60, namespace=self.name, plugins=[HitMissRatioPlugin()]) - self.cache_locks = defaultdict(asyncio.Lock) + self.cache_locks = defaultdict(Lock) @property def cache_statistics(self) -> dict[str, Any] | None: @@ -117,7 +121,7 @@ def __rich_repr__(self) -> Iterator[tuple[str, Any]]: yield "disable_cache", self.cache is None @abstractmethod - async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: + async def _collect(self, anta_commands: list[AntaCommand], *, req_format: Literal["json", "text"] = "json", req_id: str | None = None) -> None: """Collect device command output. This abstract coroutine can be used to implement any command collection method @@ -136,38 +140,39 @@ async def _collect(self, command: AntaCommand, *, collection_id: str | None = No collection_id: An identifier used to build the eAPI request ID. """ - async def collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: - """Collect the output for a specified command. + def create_eapi_request_manager(self, test_definitions: set[AntaTestDefinition], *, batch_size: int) -> EapiRequestManager: + """""" + request_manager = EapiRequestManager(self, test_definitions) + request_manager.build_requests(batch_size) - When caching is activated on both the device and the command, - this method prioritizes retrieving the output from the cache. In cases where the output isn't cached yet, - it will be freshly collected and then stored in the cache for future access. - The method employs asynchronous locks based on the command's UID to guarantee exclusive access to the cache. + return request_manager - When caching is NOT enabled, either at the device or command level, the method directly collects the output - via the private `_collect` method without interacting with the cache. + async def run(self, request_manager: EapiRequestManager, *, req_id: str): + """""" + # Collect the command outputs from the device + anta_commands = request_manager.get_commands(req_id) + try: + await self.collect_commands(anta_commands, req_format="json", req_id=req_id) + except Exception as e: # pylint: disable=broad-exception-caught + # Since device._collect() is potentially user-defined code, we need to catch all exceptions + # and report the errors for every impacted test in the request. + message = f"Exception raised while collecting commands on device {self.name}" + anta_log_exception(e, message, logger) + for impacted_test in request_manager.requests[req_id]: + impacted_test.result.is_error(message=exc_to_str(e)) + + # Once all the command outputs from a request have been collected, run the validation tests + return self.validate_commands(request_manager, req_id=req_id) + + def validate_commands(self, request_manager: EapiRequestManager, *, req_id: str) -> list[TestResult]: + """""" + + test_instances = request_manager.requests[req_id] + # Each test() method of an AntaTest instance handles exceptions and return a TestResult object + return [test_instance.test() for test_instance in test_instances] - Args: - ---- - command: The command to collect. - collection_id: An identifier used to build the eAPI request ID. - """ - # Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough - # https://github.com/pylint-dev/pylint/issues/7258 - if self.cache is not None and self.cache_locks is not None and command.use_cache: - async with self.cache_locks[command.uid]: - cached_output = await self.cache.get(command.uid) # pylint: disable=no-member - - if cached_output is not None: - logger.debug("Cache hit for %s on %s", command.command, self.name) - command.output = cached_output - else: - await self._collect(command=command, collection_id=collection_id) - await self.cache.set(command.uid, command.output) # pylint: disable=no-member - else: - await self._collect(command=command, collection_id=collection_id) - async def collect_commands(self, commands: list[AntaCommand], *, collection_id: str | None = None) -> None: + async def collect_commands(self, anta_commands: list[AntaCommand], *, req_format: Literal["text", "json"] = "json", req_id: str) -> None: """Collect multiple commands. Args: @@ -175,7 +180,31 @@ async def collect_commands(self, commands: list[AntaCommand], *, collection_id: commands: The commands to collect. collection_id: An identifier used to build the eAPI request ID. """ - await asyncio.gather(*(self.collect(command=command, collection_id=collection_id) for command in commands)) + commands_to_collect = [] + + for command in anta_commands: + if self.cache is not None and self.cache_locks is not None and command.use_cache: + async with self.cache_locks[command.uid]: + # Need to disable pylint no-member as Cache is a proxy class and pylint is not smart enough + # https://github.com/pylint-dev/pylint/issues/7258 + cached_output = await self.cache.get(command.uid) # pylint: disable=no-member + + if cached_output is not None: + logger.debug("Cache hit for %s on %s", command.command, self.name) + command.output = cached_output + else: + commands_to_collect.append(command) + else: + commands_to_collect.append(command) + + # Collect the batch of commands that are not cached + if commands_to_collect: + await self._collect(commands_to_collect, req_format=req_format, req_id=req_id) + # Cache the outputs of the collected commands + for command in commands_to_collect: + if self.cache is not None and self.cache_locks is not None and command.use_cache: + async with self.cache_locks[command.uid]: + await self.cache.set(command.uid, command.output) # pylint: disable=no-member @abstractmethod async def refresh(self) -> None: @@ -271,7 +300,7 @@ def __init__( raise ValueError(message) self.enable = enable self._enable_password = enable_password - self._session: asynceapi.Device = asynceapi.Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout) + self._session: Device = Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout) ssh_params: dict[str, Any] = {} if insecure: ssh_params["known_hosts"] = None @@ -306,7 +335,77 @@ def _keys(self) -> tuple[Any, ...]: """ return (self._session.host, self._session.port) - async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: # noqa: C901 function is too complex - because of many required except blocks #pylint: disable=line-too-long + async def _handle_eapi_command_error(self, exception: EapiCommandError, anta_commands: list[AntaCommand], *, req_format: str, req_id: str) -> None: + """Handle EapiCommandError exceptions.""" + # Populate the output attribute of the AntaCommand objects with the commands that passed + passed_outputs = exception.passed[1:] if self.enable else exception.passed + for anta_command, output in zip(anta_commands, passed_outputs): + anta_command.output = output + + # Populate the errors attribute of the AntaCommand object of the command that failed + err_at = exception.err_at - 1 if self.enable else exception.err_at + anta_command = anta_commands[err_at] + anta_command.errors = exception.errors + if anta_command.requires_privileges: + logger.error( + "Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.", + anta_command.command, + self.name, + ) + + if anta_command.supported: + error_message = exception.errors[0] if len(exception.errors) == 1 else exception.errors + logger.error( + "Command '%s' failed on %s: %s", + anta_command.command, + self.name, + error_message, + ) + else: + logger.error("Command '%s' is not supported on %s (%s).", anta_command.command, self.name, self.hw_model) + + # Collect the commands that were not executed + await self._collect(anta_commands=anta_commands[err_at + 1:], req_format=req_format, req_id=req_id) + + def _handle_timeout_exception(self, exception: TimeoutException, anta_commands: list[AntaCommand]) -> None: + """Handle TimeoutException exceptions.""" + # Populate the errors attribute of all the AntaCommand objects of the request since it failed + for anta_command in anta_commands: + anta_command.errors = [exc_to_str(exception)] + + timeouts = self._session.timeout.as_dict() + logger.error( + "%s occurred while sending commands to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s", + exc_to_str(exception), + self.name, + timeouts["connect"], + timeouts["read"], + timeouts["write"], + timeouts["pool"], + ) + + def _handle_connect_os_error(self, exception: ConnectError | OSError, anta_commands: list[AntaCommand]) -> None: + """Handle HTTPX ConnectError and OSError exceptions.""" + # Populate the errors attribute of all the AntaCommand objects of the request since it failed + for anta_command in anta_commands: + anta_command.errors = [exc_to_str(exception)] + + if (isinstance(exc := exception.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(os_error := exception, OSError): + if isinstance(os_error.__cause__, OSError): + os_error = os_error.__cause__ + logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error) + else: + anta_log_exception(exception, f"An error occurred while issuing an eAPI request to {self.name}", logger) + + def _handle_http_error(self, exception: HTTPError, anta_commands: list[AntaCommand]) -> None: + """Handle HTTPError exceptions.""" + # Populate the errors attribute of all the AntaCommand objects of the request since it failed + for anta_command in anta_commands: + anta_command.errors = [exc_to_str(exception)] + + anta_log_exception(exception, f"An error occurred while issuing an eAPI request to {self.name}", logger) + + async def _collect(self, anta_commands: list[AntaCommand], *, req_format: Literal["json", "text"] = "json", req_id: str) -> None: """Collect device command output from EOS using aio-eapi. Supports outformat `json` and `text` as output structure. @@ -318,65 +417,45 @@ async def _collect(self, command: AntaCommand, *, collection_id: str | None = No command: The command to collect. collection_id: An identifier used to build the eAPI request ID. """ - commands: list[dict[str, str | int]] = [] + # NOTE: `asynceapi` EapiCommandError exception only supports complex commands (dict) and not simple commands (str) + commands = [ + {"cmd": anta_command.command, "revision": anta_command.revision} + if anta_command.revision else {"cmd": anta_command.command} + for anta_command in anta_commands + ] + if self.enable and self._enable_password is not None: - commands.append( - { - "cmd": "enable", - "input": str(self._enable_password), - }, - ) + commands.insert(0, {"cmd": "enable", "input": str(self._enable_password)}) elif self.enable: # No password - commands.append({"cmd": "enable"}) - commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}] + commands.insert(0, {"cmd": "enable"}) + try: - response: list[dict[str, Any] | str] = await self._session.cli( + response = await self._session.cli( commands=commands, - ofmt=command.ofmt, - version=command.version, - req_id=f"ANTA-{collection_id}-{id(command)}" if collection_id else f"ANTA-{id(command)}", - ) # type: ignore[assignment] # multiple commands returns a list - # Do not keep response of 'enable' command - command.output = response[-1] - except asynceapi.EapiCommandError as e: + ofmt=req_format, + req_id=f"ANTA-{req_id}", + ) + # If enable was used, exclude the first element from the response + if self.enable: + response = response[1:] + + # Populate the output attribute of the AntaCommand objects + for anta_command, command_output in zip(anta_commands, response): + anta_command.output = command_output + + except EapiCommandError as e: # This block catches exceptions related to EOS issuing an error. - command.errors = e.errors - if command.requires_privileges: - logger.error( - "Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.", command.command, self.name - ) - if command.supported: - logger.error("Command '%s' failed on %s: %s", command.command, self.name, e.errors[0] if len(e.errors) == 1 else e.errors) - else: - logger.debug("Command '%s' is not supported on '%s' (%s)", command.command, self.name, self.hw_model) + await self._handle_eapi_command_error(e, anta_commands, req_format=req_format, req_id=req_id) except TimeoutException as e: - # This block catches Timeout exceptions. - command.errors = [exc_to_str(e)] - timeouts = self._session.timeout.as_dict() - logger.error( - "%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s", - exc_to_str(e), - self.name, - timeouts["connect"], - timeouts["read"], - timeouts["write"], - timeouts["pool"], - ) - except (ConnectError, OSError) as e: - # This block catches OSError and socket issues related exceptions. - command.errors = [exc_to_str(e)] - if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(os_error := e, OSError): # pylint: disable=no-member - if isinstance(os_error.__cause__, OSError): - os_error = os_error.__cause__ - logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error) - else: - anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger) + # This block catches exceptions related to the timeout of the request. + self._handle_timeout_exception(e, anta_commands) + except ConnectError as e: + # This block catches exceptions related to the connection to the device. + self._handle_connect_os_error(e, anta_commands) except HTTPError as e: - # This block catches most of the httpx Exceptions and logs a general message. - command.errors = [exc_to_str(e)] - anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger) - logger.debug("%s: %s", self.name, command) + # This block catches exceptions related to the HTTP connection. + self._handle_http_error(e, anta_commands) async def refresh(self) -> None: """Update attributes of an AsyncEOSDevice instance. @@ -389,8 +468,8 @@ async def refresh(self) -> None: logger.debug("Refreshing device %s", self.name) self.is_online = await self._session.check_connection() if self.is_online: - show_version = AntaCommand(command="show version") - await self._collect(show_version) + show_version = AntaCommand(command="show version", revision=1) + await self._collect([show_version], req_format="json", req_id="Refresh") if not show_version.collected: logger.warning("Cannot get hardware information from device %s", self.name) else: @@ -441,3 +520,81 @@ async def copy(self, sources: list[Path], destination: Path, direction: Literal[ return await asyncssh.scp(src, dst) + +class EapiRequestManager: + """""" + def __init__(self, device: AntaDevice, anta_test_definitions: set[AntaTestDefinition]) -> None: + """ + Initialize the EAPIRequestManager object. + + Parameters: + - device: AntaDevice instance + - anta_test_definitions: Set of AntaTestDefinition instances to be prepared + """ + self.device = device + self.anta_test_definitions = anta_test_definitions + self.requests = {} + self.commands_per_request = {} + self.current_batch = [] + self.current_batch_commands = [] + self.current_batch_size = 0 + + def get_commands(self, req_id: str) -> list[AntaCommand]: + """Get the list of AntaCommand for the specified request ID.""" + if req_id not in self.commands_per_request: + msg = f"Request ID {req_id} not found in the commands per request mapping." + raise ValueError(msg) + + return self.commands_per_request[req_id] + + def generate_request_id(self) -> str: + """Generate a unique request ID using the device name and a UUID.""" + return str(uuid4()) + + def add_new_request(self): + """Add the current batch as a new request and reset the batch.""" + request_id = self.generate_request_id() + self.requests[request_id] = self.current_batch + self.commands_per_request[request_id] = self.current_batch_commands + + # Reset the current batch and its attributes + self.current_batch = [] + self.current_batch_commands = [] + self.current_batch_size = 0 + + def build_requests(self, batch_size: int): + """Prepare the requests based on the selected tests and batch size.""" + for anta_test_definition in self.anta_test_definitions: + try: + # Instantiate the test class to build the instance commands + test_instance = anta_test_definition.test(device=self.device, inputs=anta_test_definition.inputs) + except Exception as e: # pylint: disable=broad-exception-caught + # Since an AntaTest instance is potentially user-defined code, we need to catch all exceptions + # and exit gracefully with an error message. + message = "\n".join( + [ + f"There is an error when creating test {anta_test_definition.test.module}.{anta_test_definition.test.__name__}.", + f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", + ], + ) + anta_log_exception(e, message, logger) + continue + + # Don't add blocked tests to the batch + if test_instance.blocked: + continue + + num_commands = len(test_instance.instance_commands) + + # If adding this test instance exceeds the batch size, start a new request + if self.current_batch_size + num_commands > batch_size: + self.add_new_request() + + # Add the test instance and its commands to the current batch + self.current_batch.append(test_instance) + self.current_batch_commands.extend(test_instance.instance_commands) + self.current_batch_size += num_commands + + # Add the last batch + if self.current_batch: + self.add_new_request() \ No newline at end of file diff --git a/anta/models.py b/anta/models.py index c44f7e8b4..ce112e7fd 100644 --- a/anta/models.py +++ b/anta/models.py @@ -9,6 +9,7 @@ import logging import re from abc import ABC, abstractmethod +from enum import Enum from functools import wraps from string import Formatter from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar @@ -124,6 +125,16 @@ def render(self, **params: str | int | bool) -> AntaCommand: use_cache=self.use_cache, ) +class CommandWeight(Enum): + """Enum to define the weight of a command. + + The weight of a command is used to specify the computational resources + and time required to execute the command on EOS. + """ + + LIGHT = "light" + MEDIUM = "medium" + HEAVY = "heavy" class AntaCommand(BaseModel): """Class to define a command. @@ -164,6 +175,7 @@ class AntaCommand(BaseModel): errors: list[str] = [] params: AntaParamsBaseModel = AntaParamsBaseModel() use_cache: bool = True + weight: CommandWeight = CommandWeight.LIGHT @property def uid(self) -> str: @@ -523,18 +535,6 @@ def blocked(self) -> bool: state = True return state - async def collect(self) -> None: - """Collect outputs of all commands of this test class from the device of this test instance.""" - try: - if self.blocked is False: - await self.device.collect_commands(self.instance_commands, collection_id=self.name) - except Exception as e: # pylint: disable=broad-exception-caught - # device._collect() is user-defined code. - # We need to catch everything if we want the AntaTest object - # to live until the reporting - message = f"Exception raised while collecting commands for test {self.name} (on device {self.device.name})" - anta_log_exception(e, message, self.logger) - self.result.is_error(message=exc_to_str(e)) @staticmethod def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]: @@ -549,7 +549,7 @@ def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]: """ @wraps(function) - async def wrapper( + def wrapper( self: AntaTest, eos_data: list[dict[Any, Any] | str] | None = None, **kwargs: dict[str, Any], @@ -578,7 +578,6 @@ async def wrapper( # If some data is missing, try to collect if not self.collected: - await self.collect() if self.result.result != "unset": AntaTest.update_progress() return self.result diff --git a/anta/runner.py b/anta/runner.py index 75391da8d..2ca66366e 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -10,10 +10,10 @@ import os import resource from collections import defaultdict +from itertools import chain from typing import TYPE_CHECKING, Any -from anta import GITHUB_SUGGESTION -from anta.logger import anta_log_exception, exc_to_str +from anta.logger import exc_to_str from anta.models import AntaTest from anta.tools import Catchtime, cprofile @@ -170,23 +170,10 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio """ coros = [] for device, test_definitions in selected_tests.items(): - for test in test_definitions: - try: - test_instance = test.test(device=device, inputs=test.inputs) - coros.append(test_instance.test()) - except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught - # An AntaTest instance is potentially user-defined code. - # We need to catch everything and exit gracefully with an error message. - message = "\n".join( - [ - f"There is an error when creating test {test.test.module}.{test.test.__name__}.", - f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", - ], - ) - anta_log_exception(e, message, logger) + request_manager = device.create_eapi_request_manager(test_definitions, batch_size=50) + coros.extend(device.run(request_manager, req_id=req_id) for req_id in request_manager.requests) return coros - @cprofile() async def main( # noqa: PLR0913 manager: ResultManager, @@ -260,11 +247,11 @@ async def main( # noqa: PLR0913 return if AntaTest.progress is not None: - AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coroutines)) + AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=catalog.final_tests_count) with Catchtime(logger=logger, message="Running ANTA tests"): test_results = await asyncio.gather(*coroutines) - for r in test_results: - manager.add(r) + for result in chain.from_iterable(test_results): + manager.add(result) log_cache_statistics(selected_inventory.devices) diff --git a/asynceapi/device.py b/asynceapi/device.py index 04ec3ab7c..cef55e636 100644 --- a/asynceapi/device.py +++ b/asynceapi/device.py @@ -275,6 +275,7 @@ async def jsonrpc_exec(self, jsonrpc: dict[str, Any]) -> list[dict[str, Any] | s raise EapiCommandError( passed=[get_output(cmd_data[cmd_i]) for cmd_i, cmd in enumerate(commands[:err_at])], failed=commands[err_at]["cmd"], + err_at=err_at, errors=cmd_data[err_at]["errors"], errmsg=err_msg, not_exec=commands[err_at + 1 :], diff --git a/asynceapi/errors.py b/asynceapi/errors.py index 614427a1a..d16a959cd 100644 --- a/asynceapi/errors.py +++ b/asynceapi/errors.py @@ -19,15 +19,17 @@ class EapiCommandError(RuntimeError): ---------- failed: the failed command errmsg: a description of the failure reason + err_at: the index of the command that failed errors: the command failure details passed: a list of command results of the commands that passed not_exec: a list of commands that were not executed """ - def __init__(self, failed: str, errors: list[str], errmsg: str, passed: list[str | dict[str, Any]], not_exec: list[dict[str, Any]]) -> None: # noqa: PLR0913 # pylint: disable=too-many-arguments + def __init__(self, failed: str, err_at: int, errors: list[str], errmsg: str, passed: list[str | dict[str, Any]], not_exec: list[dict[str, Any]]) -> None: # noqa: PLR0913 # pylint: disable=too-many-arguments """Initialize for the EapiCommandError exception.""" self.failed = failed self.errmsg = errmsg + self.err_at = err_at self.errors = errors self.passed = passed self.not_exec = not_exec From 3970de6d574e57a88945b70614621c7635154389 Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Wed, 12 Jun 2024 12:40:42 -0400 Subject: [PATCH 02/12] Still WIP --- anta/device.py | 11 ++++++++--- anta/runner.py | 5 +++-- asynceapi/device.py | 2 +- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/anta/device.py b/anta/device.py index 5a048701d..1f10908c5 100644 --- a/anta/device.py +++ b/anta/device.py @@ -18,7 +18,7 @@ from aiocache.plugins import HitMissRatioPlugin from asynceapi import Device, EapiCommandError from asyncssh import SSHClientConnection, SSHClientConnectionOptions -from httpx import ConnectError, HTTPError, TimeoutException +from httpx import ConnectError, HTTPError, Limits, TimeoutException from anta import __DEBUG__, GITHUB_SUGGESTION from anta.logger import anta_log_exception, exc_to_str @@ -152,6 +152,7 @@ async def run(self, request_manager: EapiRequestManager, *, req_id: str): # Collect the command outputs from the device anta_commands = request_manager.get_commands(req_id) try: + logger.debug("Collecting request ID: %s", req_id) await self.collect_commands(anta_commands, req_format="json", req_id=req_id) except Exception as e: # pylint: disable=broad-exception-caught # Since device._collect() is potentially user-defined code, we need to catch all exceptions @@ -162,7 +163,9 @@ async def run(self, request_manager: EapiRequestManager, *, req_id: str): impacted_test.result.is_error(message=exc_to_str(e)) # Once all the command outputs from a request have been collected, run the validation tests - return self.validate_commands(request_manager, req_id=req_id) + results = self.validate_commands(request_manager, req_id=req_id) + logger.debug("Finished validation commands of request ID: %s", req_id) + return results def validate_commands(self, request_manager: EapiRequestManager, *, req_id: str) -> list[TestResult]: """""" @@ -180,8 +183,10 @@ async def collect_commands(self, anta_commands: list[AntaCommand], *, req_format commands: The commands to collect. collection_id: An identifier used to build the eAPI request ID. """ + # TODO: Avoid querying the cache for the initial commands that are not cached. commands_to_collect = [] + # TODO: Don't loop over commands if the cache is disabled for command in anta_commands: if self.cache is not None and self.cache_locks is not None and command.use_cache: async with self.cache_locks[command.uid]: @@ -300,7 +305,7 @@ def __init__( raise ValueError(message) self.enable = enable self._enable_password = enable_password - self._session: Device = Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout) + self._session: Device = Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout, limits=Limits(max_connections=7)) ssh_params: dict[str, Any] = {} if insecure: ssh_params["known_hosts"] = None diff --git a/anta/runner.py b/anta/runner.py index 2ca66366e..4cf0d3f38 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -221,10 +221,13 @@ async def main( # noqa: PLR0913 if selected_tests is None: return + coroutines = get_coroutines(selected_tests) + run_info = ( "--- ANTA NRFU Run Information ---\n" f"Number of devices: {len(inventory)} ({len(selected_inventory)} established)\n" f"Total number of selected tests: {catalog.final_tests_count}\n" + f"Total number of coroutines: {len(coroutines)}\n" f"Maximum number of open file descriptors for the current ANTA process: {limits[0]}\n" "---------------------------------" ) @@ -238,8 +241,6 @@ async def main( # noqa: PLR0913 "Please consult the ANTA FAQ." ) - coroutines = get_coroutines(selected_tests) - if dry_run: logger.info("Dry-run mode, exiting before running the tests.") for coro in coroutines: diff --git a/asynceapi/device.py b/asynceapi/device.py index cef55e636..458686487 100644 --- a/asynceapi/device.py +++ b/asynceapi/device.py @@ -244,7 +244,7 @@ async def jsonrpc_exec(self, jsonrpc: dict[str, Any]) -> list[dict[str, Any] | s commands = jsonrpc["params"]["cmds"] ofmt = jsonrpc["params"]["format"] - get_output = (lambda _r: _r["output"]) if ofmt == "text" else (lambda _r: _r) + get_output = (lambda _r: _r["output"] if ofmt == "text" else (json.loads(_r) if isinstance(_r, str) else _r)) # if there are no errors then return the list of command results. if (err_data := body.get("error")) is None: From 090c73042bb37f91f68dbf091a3b015794668658 Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Fri, 14 Jun 2024 07:42:43 -0400 Subject: [PATCH 03/12] Using asyncio.Event --- anta/decorators.py | 4 +- anta/device.py | 189 ++++++++++++++++++-------------------------- anta/models.py | 23 +++++- anta/runner.py | 26 ++++-- asynceapi/device.py | 7 +- 5 files changed, 128 insertions(+), 121 deletions(-) diff --git a/anta/decorators.py b/anta/decorators.py index fd2a7dfae..dc57e13ec 100644 --- a/anta/decorators.py +++ b/anta/decorators.py @@ -88,7 +88,7 @@ def decorator(function: F) -> F: """ @wraps(function) - def wrapper(*args: Any, **kwargs: Any) -> TestResult: + async def wrapper(*args: Any, **kwargs: Any) -> TestResult: """Check the device's hardware model and conditionally run or skip the test. This wrapper inspects the hardware model of the device the test is run on. @@ -105,7 +105,7 @@ def wrapper(*args: Any, **kwargs: Any) -> TestResult: AntaTest.update_progress() return anta_test.result - return function(*args, **kwargs) + return await function(*args, **kwargs) return cast(F, wrapper) diff --git a/anta/device.py b/anta/device.py index 1f10908c5..062649754 100644 --- a/anta/device.py +++ b/anta/device.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio import logging from abc import ABC, abstractmethod from asyncio import Lock @@ -20,7 +21,7 @@ from asyncssh import SSHClientConnection, SSHClientConnectionOptions from httpx import ConnectError, HTTPError, Limits, TimeoutException -from anta import __DEBUG__, GITHUB_SUGGESTION +from anta import __DEBUG__ from anta.logger import anta_log_exception, exc_to_str from anta.models import AntaCommand @@ -28,9 +29,6 @@ from collections.abc import Iterator from pathlib import Path - from anta.catalog import AntaTestDefinition - from anta.result_manager.models import TestResult - logger = logging.getLogger(__name__) # Do not load the default keypairs multiple times due to a performance issue introduced in cryptography 37.0 @@ -140,41 +138,6 @@ async def _collect(self, anta_commands: list[AntaCommand], *, req_format: Litera collection_id: An identifier used to build the eAPI request ID. """ - def create_eapi_request_manager(self, test_definitions: set[AntaTestDefinition], *, batch_size: int) -> EapiRequestManager: - """""" - request_manager = EapiRequestManager(self, test_definitions) - request_manager.build_requests(batch_size) - - return request_manager - - async def run(self, request_manager: EapiRequestManager, *, req_id: str): - """""" - # Collect the command outputs from the device - anta_commands = request_manager.get_commands(req_id) - try: - logger.debug("Collecting request ID: %s", req_id) - await self.collect_commands(anta_commands, req_format="json", req_id=req_id) - except Exception as e: # pylint: disable=broad-exception-caught - # Since device._collect() is potentially user-defined code, we need to catch all exceptions - # and report the errors for every impacted test in the request. - message = f"Exception raised while collecting commands on device {self.name}" - anta_log_exception(e, message, logger) - for impacted_test in request_manager.requests[req_id]: - impacted_test.result.is_error(message=exc_to_str(e)) - - # Once all the command outputs from a request have been collected, run the validation tests - results = self.validate_commands(request_manager, req_id=req_id) - logger.debug("Finished validation commands of request ID: %s", req_id) - return results - - def validate_commands(self, request_manager: EapiRequestManager, *, req_id: str) -> list[TestResult]: - """""" - - test_instances = request_manager.requests[req_id] - # Each test() method of an AntaTest instance handles exceptions and return a TestResult object - return [test_instance.test() for test_instance in test_instances] - - async def collect_commands(self, anta_commands: list[AntaCommand], *, req_format: Literal["text", "json"] = "json", req_id: str) -> None: """Collect multiple commands. @@ -183,10 +146,10 @@ async def collect_commands(self, anta_commands: list[AntaCommand], *, req_format commands: The commands to collect. collection_id: An identifier used to build the eAPI request ID. """ - # TODO: Avoid querying the cache for the initial commands that are not cached. + # FIXME: Avoid querying the cache for the initial commands that are not cached. commands_to_collect = [] - # TODO: Don't loop over commands if the cache is disabled + # FIXME: Don't loop over commands if the cache is disabled for command in anta_commands: if self.cache is not None and self.cache_locks is not None and command.use_cache: async with self.cache_locks[command.uid]: @@ -374,7 +337,7 @@ async def _handle_eapi_command_error(self, exception: EapiCommandError, anta_com def _handle_timeout_exception(self, exception: TimeoutException, anta_commands: list[AntaCommand]) -> None: """Handle TimeoutException exceptions.""" - # Populate the errors attribute of all the AntaCommand objects of the request since it failed + # FIXME: Handle timeouts more gracefully for anta_command in anta_commands: anta_command.errors = [exc_to_str(exception)] @@ -391,7 +354,7 @@ def _handle_timeout_exception(self, exception: TimeoutException, anta_commands: def _handle_connect_os_error(self, exception: ConnectError | OSError, anta_commands: list[AntaCommand]) -> None: """Handle HTTPX ConnectError and OSError exceptions.""" - # Populate the errors attribute of all the AntaCommand objects of the request since it failed + # FIXME: Handle connection errors more gracefully for anta_command in anta_commands: anta_command.errors = [exc_to_str(exception)] @@ -404,7 +367,7 @@ def _handle_connect_os_error(self, exception: ConnectError | OSError, anta_comma def _handle_http_error(self, exception: HTTPError, anta_commands: list[AntaCommand]) -> None: """Handle HTTPError exceptions.""" - # Populate the errors attribute of all the AntaCommand objects of the request since it failed + # FIXME: Handle HTTP errors more gracefully for anta_command in anta_commands: anta_command.errors = [exc_to_str(exception)] @@ -422,7 +385,6 @@ async def _collect(self, anta_commands: list[AntaCommand], *, req_format: Litera command: The command to collect. collection_id: An identifier used to build the eAPI request ID. """ - # NOTE: `asynceapi` EapiCommandError exception only supports complex commands (dict) and not simple commands (str) commands = [ {"cmd": anta_command.command, "revision": anta_command.revision} if anta_command.revision else {"cmd": anta_command.command} @@ -526,80 +488,87 @@ async def copy(self, sources: list[Path], destination: Path, direction: Literal[ return await asyncssh.scp(src, dst) -class EapiRequestManager: - """""" - def __init__(self, device: AntaDevice, anta_test_definitions: set[AntaTestDefinition]) -> None: +class RequestManager: + """Request Manager class to handle sending requests to a device. + + # FIXME: Handle text output format + # FIXME: Handle the case where the last batch is less than the batch size + # FIXME: Handle different batch sizes for different tests + # TODO: Investigate if we should transform this class into an async context manager + # TODO: Investigate if asyncio.Condition is a better choice than asyncio.Event to signal request completion + """ + + def __init__(self, device: AntaDevice, batch_size: int) -> None: """ - Initialize the EAPIRequestManager object. + Initialize the RequestManager object. - Parameters: - - device: AntaDevice instance - - anta_test_definitions: Set of AntaTestDefinition instances to be prepared + Arguments: + ---------- + device: The device object to send the requests to. + batch_size: The maximum number of commands to send in a single request. """ self.device = device - self.anta_test_definitions = anta_test_definitions - self.requests = {} - self.commands_per_request = {} - self.current_batch = [] + self.batch_size = batch_size self.current_batch_commands = [] + self.current_batch_request_ids = set() + self.current_batch_size = 0 + self.lock = asyncio.Lock() + self.pending_requests: dict[str, asyncio.Event] = {} + + async def add_commands(self, commands: list[AntaCommand]) -> None: + """Add the commands to the current batch.""" + async with self.lock: + self.current_batch_commands.extend(commands) + self.current_batch_size += len(commands) + + request_id = self.generate_request_id() + self.pending_requests[request_id] = asyncio.Event() + self.current_batch_request_ids.add(request_id) + + if self.current_batch_size >= self.batch_size: + logger.debug("Current batch size (%s) exceeded the batch size limit (%s)", self.current_batch_size, self.batch_size) + # Reset the current batch and send the request + await self.send_eapi_request() + + return request_id + + async def send_eapi_request(self) -> None: + """Send the current batch as a request.""" + eapi_request_id = self.generate_request_id() + logger.debug("Sending eAPI request ID: %s", eapi_request_id) + + task = asyncio.create_task( + self.device.collect_commands(self.current_batch_commands, req_format="json", req_id=eapi_request_id), + name=f"Request ID {eapi_request_id} on {self.device.name}", + ) + task.add_done_callback(lambda t: self.on_request_complete(t, self.current_batch_request_ids)) + + # Reset the current batch and its attributes + self.current_batch_commands.clear() + self.current_batch_request_ids.clear() self.current_batch_size = 0 - def get_commands(self, req_id: str) -> list[AntaCommand]: - """Get the list of AntaCommand for the specified request ID.""" - if req_id not in self.commands_per_request: - msg = f"Request ID {req_id} not found in the commands per request mapping." - raise ValueError(msg) + def on_request_complete(self, task: asyncio.Task, request_ids: set[str]) -> None: + """Set the event when the request is complete.""" + task_name = task.get_name() + try: + if task.cancelled(): + logger.warning("%s was cancelled", task_name) + elif task.exception(): + logger.error("%s failed: %s", task_name, task.exception()) + else: + logger.debug("%s succeeded with result: %s", task_name, task.result()) + except asyncio.CancelledError: + logger.warning("%s was cancelled unexpectedly", task_name) - return self.commands_per_request[req_id] + for request_id in request_ids: + self.pending_requests[request_id].set() + del self.pending_requests[request_id] def generate_request_id(self) -> str: - """Generate a unique request ID using the device name and a UUID.""" + """Generate a unique request ID using a UUID.""" return str(uuid4()) - def add_new_request(self): - """Add the current batch as a new request and reset the batch.""" - request_id = self.generate_request_id() - self.requests[request_id] = self.current_batch - self.commands_per_request[request_id] = self.current_batch_commands - - # Reset the current batch and its attributes - self.current_batch = [] - self.current_batch_commands = [] - self.current_batch_size = 0 - - def build_requests(self, batch_size: int): - """Prepare the requests based on the selected tests and batch size.""" - for anta_test_definition in self.anta_test_definitions: - try: - # Instantiate the test class to build the instance commands - test_instance = anta_test_definition.test(device=self.device, inputs=anta_test_definition.inputs) - except Exception as e: # pylint: disable=broad-exception-caught - # Since an AntaTest instance is potentially user-defined code, we need to catch all exceptions - # and exit gracefully with an error message. - message = "\n".join( - [ - f"There is an error when creating test {anta_test_definition.test.module}.{anta_test_definition.test.__name__}.", - f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", - ], - ) - anta_log_exception(e, message, logger) - continue - - # Don't add blocked tests to the batch - if test_instance.blocked: - continue - - num_commands = len(test_instance.instance_commands) - - # If adding this test instance exceeds the batch size, start a new request - if self.current_batch_size + num_commands > batch_size: - self.add_new_request() - - # Add the test instance and its commands to the current batch - self.current_batch.append(test_instance) - self.current_batch_commands.extend(test_instance.instance_commands) - self.current_batch_size += num_commands - - # Add the last batch - if self.current_batch: - self.add_new_request() \ No newline at end of file + async def wait_for_request(self, request_id: str) -> None: + """Wait for a specific request to complete.""" + await self.pending_requests[request_id].wait() diff --git a/anta/models.py b/anta/models.py index ce112e7fd..f821d1f76 100644 --- a/anta/models.py +++ b/anta/models.py @@ -26,7 +26,7 @@ from rich.progress import Progress, TaskID - from anta.device import AntaDevice + from anta.device import AntaDevice, RequestManager F = TypeVar("F", bound=Callable[..., Any]) # Proper way to type input class - revisit this later if we get any issue @gmuloc @@ -388,6 +388,7 @@ class Filters(BaseModel): def __init__( self, device: AntaDevice, + request_manager: RequestManager, inputs: dict[str, Any] | AntaTest.Input | None = None, eos_data: list[dict[Any, Any] | str] | None = None, ) -> None: @@ -402,6 +403,7 @@ def __init__( """ self.logger: logging.Logger = logging.getLogger(f"{self.module}.{self.__class__.__name__}") self.device: AntaDevice = device + self.request_manager: RequestManager = request_manager self.inputs: AntaTest.Input self.instance_commands: list[AntaCommand] = [] self.result: TestResult = TestResult( @@ -535,6 +537,18 @@ def blocked(self) -> bool: state = True return state + async def send_commands(self) -> str: + """Collect outputs of all commands of this test class from the device of this test instance.""" + try: + if self.blocked is False: + return await self.request_manager.add_commands(self.instance_commands) + except Exception as e: # pylint: disable=broad-exception-caught + # device._collect() is user-defined code. + # We need to catch everything if we want the AntaTest object + # to live until the reporting + message = f"Exception raised while collecting commands for test {self.name} (on device {self.device.name})" + anta_log_exception(e, message, self.logger) + self.result.is_error(message=exc_to_str(e)) @staticmethod def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]: @@ -549,7 +563,7 @@ def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]: """ @wraps(function) - def wrapper( + async def wrapper( self: AntaTest, eos_data: list[dict[Any, Any] | str] | None = None, **kwargs: dict[str, Any], @@ -576,12 +590,15 @@ def wrapper( self.save_commands_data(eos_data) self.logger.debug("Test %s initialized with input data %s", self.name, eos_data) - # If some data is missing, try to collect + # If the commands have not been collected, send them to the request manager and wait for the results if not self.collected: + request_id = await self.send_commands() + await self.request_manager.wait_for_request(request_id) if self.result.result != "unset": AntaTest.update_progress() return self.result + logger.debug("All commands have been collected for test %s", self.name) if cmds := self.failed_commands: unsupported_commands = [f"'{c.command}' is not supported on {self.device.hw_model}" for c in cmds if not c.supported] if unsupported_commands: diff --git a/anta/runner.py b/anta/runner.py index 4cf0d3f38..11818f1b0 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -10,10 +10,11 @@ import os import resource from collections import defaultdict -from itertools import chain from typing import TYPE_CHECKING, Any -from anta.logger import exc_to_str +from anta import GITHUB_SUGGESTION +from anta.device import RequestManager +from anta.logger import anta_log_exception, exc_to_str from anta.models import AntaTest from anta.tools import Catchtime, cprofile @@ -168,10 +169,25 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio ------- The list of coroutines to run. """ + # FIXME: That could be a generator instead of a list coros = [] for device, test_definitions in selected_tests.items(): - request_manager = device.create_eapi_request_manager(test_definitions, batch_size=50) - coros.extend(device.run(request_manager, req_id=req_id) for req_id in request_manager.requests) + request_manager = RequestManager(device, batch_size=1) + for test in test_definitions: + try: + test_instance = test.test(device=device, request_manager=request_manager, inputs=test.inputs) + coros.append(test_instance.test()) + except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught + # An AntaTest instance is potentially user-defined code. + # We need to catch everything and exit gracefully with an error message. + message = "\n".join( + [ + f"There is an error when creating test {test.test.module}.{test.test.__name__}.", + f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", + ], + ) + anta_log_exception(e, message, logger) + return coros @cprofile() @@ -252,7 +268,7 @@ async def main( # noqa: PLR0913 with Catchtime(logger=logger, message="Running ANTA tests"): test_results = await asyncio.gather(*coroutines) - for result in chain.from_iterable(test_results): + for result in test_results: manager.add(result) log_cache_statistics(selected_inventory.devices) diff --git a/asynceapi/device.py b/asynceapi/device.py index 458686487..9fd887589 100644 --- a/asynceapi/device.py +++ b/asynceapi/device.py @@ -9,6 +9,7 @@ from __future__ import annotations +from json import loads from socket import getservbyname from typing import TYPE_CHECKING, Any @@ -244,7 +245,9 @@ async def jsonrpc_exec(self, jsonrpc: dict[str, Any]) -> list[dict[str, Any] | s commands = jsonrpc["params"]["cmds"] ofmt = jsonrpc["params"]["format"] - get_output = (lambda _r: _r["output"] if ofmt == "text" else (json.loads(_r) if isinstance(_r, str) else _r)) + # Return the correct output format based on the requested ofmt. + def get_output(response: dict[str, Any]) -> str | dict[str, Any]: + return response["output"] if ofmt == "text" else loads(response) if isinstance(response, str) else response # if there are no errors then return the list of command results. if (err_data := body.get("error")) is None: @@ -272,6 +275,8 @@ async def jsonrpc_exec(self, jsonrpc: dict[str, Any]) -> list[dict[str, Any] | s err_at = len_data - 1 err_msg = err_data["message"] + # FIXME: EapiCommandError exception only supports complex commands (dict) and not simple commands (str) + # https://github.com/aristanetworks/anta/issues/718 raise EapiCommandError( passed=[get_output(cmd_data[cmd_i]) for cmd_i, cmd in enumerate(commands[:err_at])], failed=commands[err_at]["cmd"], From 650337b4a43192ccbcafb3147e76d0d150787ce5 Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Fri, 14 Jun 2024 16:25:26 -0400 Subject: [PATCH 04/12] Adding last batch --- anta/device.py | 38 +++++++++++++++++++++++++++++++------- anta/models.py | 2 +- anta/runner.py | 38 +++++++++++++++----------------------- 3 files changed, 47 insertions(+), 31 deletions(-) diff --git a/anta/device.py b/anta/device.py index 062649754..b4a644525 100644 --- a/anta/device.py +++ b/anta/device.py @@ -20,6 +20,7 @@ from asynceapi import Device, EapiCommandError from asyncssh import SSHClientConnection, SSHClientConnectionOptions from httpx import ConnectError, HTTPError, Limits, TimeoutException +from typing_extensions import Self from anta import __DEBUG__ from anta.logger import anta_log_exception, exc_to_str @@ -28,6 +29,8 @@ if TYPE_CHECKING: from collections.abc import Iterator from pathlib import Path + from types import TracebackType + from anta.models import AntaTest logger = logging.getLogger(__name__) @@ -514,10 +517,29 @@ def __init__(self, device: AntaDevice, batch_size: int) -> None: self.current_batch_size = 0 self.lock = asyncio.Lock() self.pending_requests: dict[str, asyncio.Event] = {} + self.test_instances = set() - async def add_commands(self, commands: list[AntaCommand]) -> None: + async def __aenter__(self) -> Self: + """Enter the async context manager.""" + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: + """Exit the async context and send any remaining commands.""" + async with self.lock: + if self.current_batch_commands: + logger.warning("Exiting RequestManager context with pending commands") + await self.send_eapi_request() + else: + logger.warning("Exiting RequestManager context with no pending commands") + + + async def add_commands(self, commands: list[AntaCommand], test_instance: AntaTest) -> None: """Add the commands to the current batch.""" async with self.lock: + # Remove the test instance from the tracking set since its commands are being processed + if test_instance in self.test_instances: + self.test_instances.remove(test_instance) + self.current_batch_commands.extend(commands) self.current_batch_size += len(commands) @@ -525,9 +547,8 @@ async def add_commands(self, commands: list[AntaCommand]) -> None: self.pending_requests[request_id] = asyncio.Event() self.current_batch_request_ids.add(request_id) - if self.current_batch_size >= self.batch_size: - logger.debug("Current batch size (%s) exceeded the batch size limit (%s)", self.current_batch_size, self.batch_size) - # Reset the current batch and send the request + # Send the request if the batch size is exceeded or there are no more test instances to process + if self.current_batch_size >= self.batch_size or not self.test_instances: await self.send_eapi_request() return request_id @@ -537,11 +558,14 @@ async def send_eapi_request(self) -> None: eapi_request_id = self.generate_request_id() logger.debug("Sending eAPI request ID: %s", eapi_request_id) + current_batch_commands = self.current_batch_commands.copy() + current_batch_request_ids = self.current_batch_request_ids.copy() + task = asyncio.create_task( - self.device.collect_commands(self.current_batch_commands, req_format="json", req_id=eapi_request_id), + self.device.collect_commands(current_batch_commands, req_format="json", req_id=eapi_request_id), name=f"Request ID {eapi_request_id} on {self.device.name}", ) - task.add_done_callback(lambda t: self.on_request_complete(t, self.current_batch_request_ids)) + task.add_done_callback(lambda t: self.on_request_complete(t, current_batch_request_ids)) # Reset the current batch and its attributes self.current_batch_commands.clear() @@ -557,7 +581,7 @@ def on_request_complete(self, task: asyncio.Task, request_ids: set[str]) -> None elif task.exception(): logger.error("%s failed: %s", task_name, task.exception()) else: - logger.debug("%s succeeded with result: %s", task_name, task.result()) + logger.debug("%s succeeded", task_name) except asyncio.CancelledError: logger.warning("%s was cancelled unexpectedly", task_name) diff --git a/anta/models.py b/anta/models.py index f821d1f76..521165878 100644 --- a/anta/models.py +++ b/anta/models.py @@ -541,7 +541,7 @@ async def send_commands(self) -> str: """Collect outputs of all commands of this test class from the device of this test instance.""" try: if self.blocked is False: - return await self.request_manager.add_commands(self.instance_commands) + return await self.request_manager.add_commands(self.instance_commands, test_instance=self) except Exception as e: # pylint: disable=broad-exception-caught # device._collect() is user-defined code. # We need to catch everything if we want the AntaTest object diff --git a/anta/runner.py b/anta/runner.py index 11818f1b0..8b6da0a57 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -157,26 +157,17 @@ def prepare_tests( return device_to_tests - -def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]]) -> list[Coroutine[Any, Any, TestResult]]: - """Get the coroutines for the ANTA run. - - Args: - ---- - selected_tests: A mapping of devices to the tests to run. The selected tests are generated by the `prepare_tests` function. - - Returns - ------- - The list of coroutines to run. - """ - # FIXME: That could be a generator instead of a list - coros = [] - for device, test_definitions in selected_tests.items(): - request_manager = RequestManager(device, batch_size=1) +async def run_device_tests(device: AntaDevice, test_definitions: set[AntaTestDefinition], batch_size: int) -> list[TestResult]: + """Run tests for a specific device using the RequestManager.""" + async with RequestManager(device, batch_size) as request_manager: + coros = [] for test in test_definitions: try: test_instance = test.test(device=device, request_manager=request_manager, inputs=test.inputs) coros.append(test_instance.test()) + + # Add the instance to the request manager to track it + request_manager.test_instances.add(test_instance) except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught # An AntaTest instance is potentially user-defined code. # We need to catch everything and exit gracefully with an error message. @@ -187,8 +178,9 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio ], ) anta_log_exception(e, message, logger) + results = await asyncio.gather(*coros) + return results - return coros @cprofile() async def main( # noqa: PLR0913 @@ -237,13 +229,13 @@ async def main( # noqa: PLR0913 if selected_tests is None: return - coroutines = get_coroutines(selected_tests) + device_coroutines = [run_device_tests(device, test_definitions, batch_size=100) for device, test_definitions in selected_tests.items()] run_info = ( "--- ANTA NRFU Run Information ---\n" f"Number of devices: {len(inventory)} ({len(selected_inventory)} established)\n" f"Total number of selected tests: {catalog.final_tests_count}\n" - f"Total number of coroutines: {len(coroutines)}\n" + f"Total number of coroutines: {len(device_coroutines)}\n" f"Maximum number of open file descriptors for the current ANTA process: {limits[0]}\n" "---------------------------------" ) @@ -259,7 +251,7 @@ async def main( # noqa: PLR0913 if dry_run: logger.info("Dry-run mode, exiting before running the tests.") - for coro in coroutines: + for coro in device_coroutines: coro.close() return @@ -267,8 +259,8 @@ async def main( # noqa: PLR0913 AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=catalog.final_tests_count) with Catchtime(logger=logger, message="Running ANTA tests"): - test_results = await asyncio.gather(*coroutines) - for result in test_results: - manager.add(result) + test_results = await asyncio.gather(*device_coroutines) + for r in test_results: + manager.add(r) log_cache_statistics(selected_inventory.devices) From d750fa6eaa0567ebdf641fbaa57fa54df0796d7c Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Sat, 15 Jun 2024 15:40:47 -0400 Subject: [PATCH 05/12] Working state with asyncio.Condition --- anta/device.py | 56 ++++++++++++++++++++++++---------------------- anta/models.py | 10 ++++++++- anta/runner.py | 60 +++++++++++++++++++++++++++----------------------- 3 files changed, 71 insertions(+), 55 deletions(-) diff --git a/anta/device.py b/anta/device.py index b4a644525..310a7f064 100644 --- a/anta/device.py +++ b/anta/device.py @@ -30,7 +30,6 @@ from collections.abc import Iterator from pathlib import Path from types import TracebackType - from anta.models import AntaTest logger = logging.getLogger(__name__) @@ -497,6 +496,7 @@ class RequestManager: # FIXME: Handle text output format # FIXME: Handle the case where the last batch is less than the batch size # FIXME: Handle different batch sizes for different tests + # FIXME: Handle the case where a single test send more than one batch # TODO: Investigate if we should transform this class into an async context manager # TODO: Investigate if asyncio.Condition is a better choice than asyncio.Event to signal request completion """ @@ -512,12 +512,15 @@ def __init__(self, device: AntaDevice, batch_size: int) -> None: """ self.device = device self.batch_size = batch_size - self.current_batch_commands = [] - self.current_batch_request_ids = set() + self.condition = asyncio.Condition() + self.completed_coroutines = 0 + self.current_batch_commands: list[AntaCommand] = [] + self.current_batch_request_ids: set[str] = set() self.current_batch_size = 0 self.lock = asyncio.Lock() self.pending_requests: dict[str, asyncio.Event] = {} - self.test_instances = set() + self.final_commands: dict[str, list[AntaCommand]] = {} + self.final_request_ids: dict[str, set[str]] = {} async def __aenter__(self) -> Self: """Enter the async context manager.""" @@ -532,14 +535,9 @@ async def __aexit__(self, exc_type: type[BaseException] | None, exc: BaseExcepti else: logger.warning("Exiting RequestManager context with no pending commands") - - async def add_commands(self, commands: list[AntaCommand], test_instance: AntaTest) -> None: + async def add_commands(self, commands: list[AntaCommand]) -> None: """Add the commands to the current batch.""" async with self.lock: - # Remove the test instance from the tracking set since its commands are being processed - if test_instance in self.test_instances: - self.test_instances.remove(test_instance) - self.current_batch_commands.extend(commands) self.current_batch_size += len(commands) @@ -547,31 +545,37 @@ async def add_commands(self, commands: list[AntaCommand], test_instance: AntaTes self.pending_requests[request_id] = asyncio.Event() self.current_batch_request_ids.add(request_id) - # Send the request if the batch size is exceeded or there are no more test instances to process - if self.current_batch_size >= self.batch_size or not self.test_instances: - await self.send_eapi_request() + # Once the batch size is reached, add it to the batches list + if self.current_batch_size >= self.batch_size: + await self.add_batch() return request_id - async def send_eapi_request(self) -> None: - """Send the current batch as a request.""" - eapi_request_id = self.generate_request_id() - logger.debug("Sending eAPI request ID: %s", eapi_request_id) - - current_batch_commands = self.current_batch_commands.copy() - current_batch_request_ids = self.current_batch_request_ids.copy() - - task = asyncio.create_task( - self.device.collect_commands(current_batch_commands, req_format="json", req_id=eapi_request_id), - name=f"Request ID {eapi_request_id} on {self.device.name}", - ) - task.add_done_callback(lambda t: self.on_request_complete(t, current_batch_request_ids)) + async def add_batch(self) -> None: + """Add the current batch to the batches list.""" + batch_id = self.generate_request_id() + self.final_commands[batch_id] = self.current_batch_commands.copy() + self.final_request_ids[batch_id] = self.current_batch_request_ids.copy() # Reset the current batch and its attributes self.current_batch_commands.clear() self.current_batch_request_ids.clear() self.current_batch_size = 0 + async def send_eapi_requests(self) -> None: + """Send all the requests from the batches mapping.""" + tasks = [] + for batch_id, commands in self.final_commands.items(): + eapi_request_id = self.generate_request_id() + + task = asyncio.create_task( + self.device.collect_commands(commands, req_format="json", req_id=eapi_request_id), + name=f"Request ID {eapi_request_id} on {self.device.name}", + ) + task.add_done_callback(lambda t, i=batch_id: self.on_request_complete(t, self.final_request_ids[i])) + tasks.append(task) + await asyncio.gather(*tasks) + def on_request_complete(self, task: asyncio.Task, request_ids: set[str]) -> None: """Set the event when the request is complete.""" task_name = task.get_name() diff --git a/anta/models.py b/anta/models.py index 521165878..78e2ca6d7 100644 --- a/anta/models.py +++ b/anta/models.py @@ -22,6 +22,7 @@ from anta.result_manager.models import TestResult if TYPE_CHECKING: + import asyncio from collections.abc import Coroutine from rich.progress import Progress, TaskID @@ -541,7 +542,7 @@ async def send_commands(self) -> str: """Collect outputs of all commands of this test class from the device of this test instance.""" try: if self.blocked is False: - return await self.request_manager.add_commands(self.instance_commands, test_instance=self) + return await self.request_manager.add_commands(self.instance_commands) except Exception as e: # pylint: disable=broad-exception-caught # device._collect() is user-defined code. # We need to catch everything if we want the AntaTest object @@ -593,6 +594,13 @@ async def wrapper( # If the commands have not been collected, send them to the request manager and wait for the results if not self.collected: request_id = await self.send_commands() + + # Signal that this test coroutine has completed sending commands + async with self.request_manager.condition: + self.request_manager.completed_coroutines += 1 + self.request_manager.condition.notify() + + # Wait for the request containing the commands to complete await self.request_manager.wait_for_request(request_id) if self.result.result != "unset": AntaTest.update_progress() diff --git a/anta/runner.py b/anta/runner.py index 8b6da0a57..d10419972 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -10,7 +10,8 @@ import os import resource from collections import defaultdict -from typing import TYPE_CHECKING, Any +from itertools import chain +from typing import TYPE_CHECKING from anta import GITHUB_SUGGESTION from anta.device import RequestManager @@ -19,8 +20,6 @@ from anta.tools import Catchtime, cprofile if TYPE_CHECKING: - from collections.abc import Coroutine - from anta.catalog import AntaCatalog, AntaTestDefinition from anta.device import AntaDevice from anta.inventory import AntaInventory @@ -159,28 +158,33 @@ def prepare_tests( async def run_device_tests(device: AntaDevice, test_definitions: set[AntaTestDefinition], batch_size: int) -> list[TestResult]: """Run tests for a specific device using the RequestManager.""" - async with RequestManager(device, batch_size) as request_manager: - coros = [] - for test in test_definitions: - try: - test_instance = test.test(device=device, request_manager=request_manager, inputs=test.inputs) - coros.append(test_instance.test()) - - # Add the instance to the request manager to track it - request_manager.test_instances.add(test_instance) - except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught - # An AntaTest instance is potentially user-defined code. - # We need to catch everything and exit gracefully with an error message. - message = "\n".join( - [ - f"There is an error when creating test {test.test.module}.{test.test.__name__}.", - f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", - ], - ) - anta_log_exception(e, message, logger) - results = await asyncio.gather(*coros) - return results + tasks = [] + request_manager = RequestManager(device=device, batch_size=batch_size) + for test in test_definitions: + try: + test_instance = test.test(device=device, request_manager=request_manager, inputs=test.inputs) + tasks.append(asyncio.create_task(test_instance.test())) + except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught + # An AntaTest instance is potentially user-defined code. + # We need to catch everything and exit gracefully with an error message. + message = "\n".join( + [ + f"There is an error when creating test {test.test.module}.{test.test.__name__}.", + f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", + ], + ) + anta_log_exception(e, message, logger) + + async with request_manager.condition: + logger.warning("Waiting for all tests to submit their commands to the Request Manager.") + await request_manager.condition.wait_for(lambda: request_manager.completed_coroutines == len(tasks)) + logger.warning("All tests for device %s have submitted their commands to the Request Manager.", device.name) + + # Tell the RequestManager to send the requests to the device + await request_manager.send_eapi_requests() + # Return the results of the tests + return [task.result() for task in tasks] @cprofile() async def main( # noqa: PLR0913 @@ -229,7 +233,7 @@ async def main( # noqa: PLR0913 if selected_tests is None: return - device_coroutines = [run_device_tests(device, test_definitions, batch_size=100) for device, test_definitions in selected_tests.items()] + device_coroutines = [run_device_tests(device, test_definitions, batch_size=1) for device, test_definitions in selected_tests.items()] run_info = ( "--- ANTA NRFU Run Information ---\n" @@ -259,8 +263,8 @@ async def main( # noqa: PLR0913 AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=catalog.final_tests_count) with Catchtime(logger=logger, message="Running ANTA tests"): - test_results = await asyncio.gather(*device_coroutines) - for r in test_results: - manager.add(r) + results = chain.from_iterable(await asyncio.gather(*device_coroutines)) + for result in results: + manager.add(result) log_cache_statistics(selected_inventory.devices) From 1d00f167ad0d7257f2ae796c293b9685e4b6aea6 Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Sat, 15 Jun 2024 16:24:59 -0400 Subject: [PATCH 06/12] Added logs and last batch --- anta/device.py | 4 ++++ anta/models.py | 4 +++- anta/runner.py | 18 +++++++++++------- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/anta/device.py b/anta/device.py index 310a7f064..cba9ebe60 100644 --- a/anta/device.py +++ b/anta/device.py @@ -564,6 +564,10 @@ async def add_batch(self) -> None: async def send_eapi_requests(self) -> None: """Send all the requests from the batches mapping.""" + # Check if there are any commands left in the current batch + if self.current_batch_commands: + await self.add_batch() + tasks = [] for batch_id, commands in self.final_commands.items(): eapi_request_id = self.generate_request_id() diff --git a/anta/models.py b/anta/models.py index 78e2ca6d7..34f7f5786 100644 --- a/anta/models.py +++ b/anta/models.py @@ -593,6 +593,7 @@ async def wrapper( # If the commands have not been collected, send them to the request manager and wait for the results if not self.collected: + logger.debug("<%s>: Sending commands for test %s to the Result Manager", self.device.name, self.name) request_id = await self.send_commands() # Signal that this test coroutine has completed sending commands @@ -602,11 +603,12 @@ async def wrapper( # Wait for the request containing the commands to complete await self.request_manager.wait_for_request(request_id) + logger.debug("<%s>: All commands have been collected for test %s", self.device.name, self.name) + if self.result.result != "unset": AntaTest.update_progress() return self.result - logger.debug("All commands have been collected for test %s", self.name) if cmds := self.failed_commands: unsupported_commands = [f"'{c.command}' is not supported on {self.device.hw_model}" for c in cmds if not c.supported] if unsupported_commands: diff --git a/anta/runner.py b/anta/runner.py index d10419972..2e55449ae 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -161,29 +161,34 @@ async def run_device_tests(device: AntaDevice, test_definitions: set[AntaTestDef tasks = [] request_manager = RequestManager(device=device, batch_size=batch_size) for test in test_definitions: + full_name = f"{test.test.module}.{test.test.__name__}" try: test_instance = test.test(device=device, request_manager=request_manager, inputs=test.inputs) - tasks.append(asyncio.create_task(test_instance.test())) + task = asyncio.create_task(test_instance.test(), name=full_name) + logger.debug("Creating task: %s", task.get_name()) + tasks.append(task) except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught # An AntaTest instance is potentially user-defined code. # We need to catch everything and exit gracefully with an error message. message = "\n".join( [ - f"There is an error when creating test {test.test.module}.{test.test.__name__}.", + f"There is an error when creating test {full_name}.", f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", ], ) anta_log_exception(e, message, logger) async with request_manager.condition: - logger.warning("Waiting for all tests to submit their commands to the Request Manager.") + logger.debug("<%s>: Waiting for all tests to submit their commands to the Request Manager", device.name) await request_manager.condition.wait_for(lambda: request_manager.completed_coroutines == len(tasks)) - logger.warning("All tests for device %s have submitted their commands to the Request Manager.", device.name) + logger.debug("<%s>: All tests have submitted their commands to the Request Manager", device.name) - # Tell the RequestManager to send the requests to the device + # Tell the RequestManager to send all requests to the device + logger.debug("<%s>: Sending all eAPI requests to the device", device.name) await request_manager.send_eapi_requests() # Return the results of the tests + logger.debug("<%s>: All tests completed", device.name) return [task.result() for task in tasks] @cprofile() @@ -233,13 +238,12 @@ async def main( # noqa: PLR0913 if selected_tests is None: return - device_coroutines = [run_device_tests(device, test_definitions, batch_size=1) for device, test_definitions in selected_tests.items()] + device_coroutines = [run_device_tests(device, test_definitions, batch_size=250) for device, test_definitions in selected_tests.items()] run_info = ( "--- ANTA NRFU Run Information ---\n" f"Number of devices: {len(inventory)} ({len(selected_inventory)} established)\n" f"Total number of selected tests: {catalog.final_tests_count}\n" - f"Total number of coroutines: {len(device_coroutines)}\n" f"Maximum number of open file descriptors for the current ANTA process: {limits[0]}\n" "---------------------------------" ) From 0c9c0a0a3a6fbfaa16f8f84c80a9e79d1f89cf44 Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Sat, 15 Jun 2024 17:01:38 -0400 Subject: [PATCH 07/12] Added wait_for_commands --- anta/device.py | 7 +++++++ anta/runner.py | 6 ++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/anta/device.py b/anta/device.py index cba9ebe60..d760d47b8 100644 --- a/anta/device.py +++ b/anta/device.py @@ -535,6 +535,13 @@ async def __aexit__(self, exc_type: type[BaseException] | None, exc: BaseExcepti else: logger.warning("Exiting RequestManager context with no pending commands") + async def wait_for_commands(self, total_tasks: int) -> None: + """Wait until all commands from the coroutine tests are received.""" + async with self.condition: + logger.debug("<%s>: Waiting for all tests to submit their commands to the Request Manager", self.device.name) + await self.condition.wait_for(lambda: self.completed_coroutines == total_tasks) + logger.debug("<%s>: All tests have submitted their commands to the Request Manager", self.device.name) + async def add_commands(self, commands: list[AntaCommand]) -> None: """Add the commands to the current batch.""" async with self.lock: diff --git a/anta/runner.py b/anta/runner.py index 2e55449ae..43bfd9190 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -178,10 +178,8 @@ async def run_device_tests(device: AntaDevice, test_definitions: set[AntaTestDef ) anta_log_exception(e, message, logger) - async with request_manager.condition: - logger.debug("<%s>: Waiting for all tests to submit their commands to the Request Manager", device.name) - await request_manager.condition.wait_for(lambda: request_manager.completed_coroutines == len(tasks)) - logger.debug("<%s>: All tests have submitted their commands to the Request Manager", device.name) + # Wait until all commands from all tests are sent to the RequestManager + await request_manager.wait_for_commands(total_tasks=len(tasks)) # Tell the RequestManager to send all requests to the device logger.debug("<%s>: Sending all eAPI requests to the device", device.name) From d6eff98bfb15fc35ef9fe3d6c7d1c49c2f335c77 Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Sat, 15 Jun 2024 17:10:49 -0400 Subject: [PATCH 08/12] Clean-up --- anta/device.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/anta/device.py b/anta/device.py index d760d47b8..1d9e288f6 100644 --- a/anta/device.py +++ b/anta/device.py @@ -8,7 +8,6 @@ import asyncio import logging from abc import ABC, abstractmethod -from asyncio import Lock from collections import defaultdict from typing import TYPE_CHECKING, Any, Literal from uuid import uuid4 @@ -20,7 +19,6 @@ from asynceapi import Device, EapiCommandError from asyncssh import SSHClientConnection, SSHClientConnectionOptions from httpx import ConnectError, HTTPError, Limits, TimeoutException -from typing_extensions import Self from anta import __DEBUG__ from anta.logger import anta_log_exception, exc_to_str @@ -29,7 +27,6 @@ if TYPE_CHECKING: from collections.abc import Iterator from pathlib import Path - from types import TracebackType logger = logging.getLogger(__name__) @@ -74,7 +71,7 @@ def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bo self.is_online: bool = False self.established: bool = False self.cache: Cache | None = None - self.cache_locks: defaultdict[str, Lock] | None = None + self.cache_locks: defaultdict[str, asyncio.Lock] | None = None # Initialize cache if not disabled if not disable_cache: @@ -96,7 +93,7 @@ def __hash__(self) -> int: def _init_cache(self) -> None: """Initialize cache for the device, can be overridden by subclasses to manipulate how it works.""" self.cache = Cache(cache_class=Cache.MEMORY, ttl=60, namespace=self.name, plugins=[HitMissRatioPlugin()]) - self.cache_locks = defaultdict(Lock) + self.cache_locks = defaultdict(asyncio.Lock) @property def cache_statistics(self) -> dict[str, Any] | None: @@ -494,11 +491,9 @@ class RequestManager: """Request Manager class to handle sending requests to a device. # FIXME: Handle text output format - # FIXME: Handle the case where the last batch is less than the batch size # FIXME: Handle different batch sizes for different tests # FIXME: Handle the case where a single test send more than one batch - # TODO: Investigate if we should transform this class into an async context manager - # TODO: Investigate if asyncio.Condition is a better choice than asyncio.Event to signal request completion + # FIXME: Cleanup multiple attributes """ def __init__(self, device: AntaDevice, batch_size: int) -> None: @@ -522,19 +517,6 @@ def __init__(self, device: AntaDevice, batch_size: int) -> None: self.final_commands: dict[str, list[AntaCommand]] = {} self.final_request_ids: dict[str, set[str]] = {} - async def __aenter__(self) -> Self: - """Enter the async context manager.""" - return self - - async def __aexit__(self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: - """Exit the async context and send any remaining commands.""" - async with self.lock: - if self.current_batch_commands: - logger.warning("Exiting RequestManager context with pending commands") - await self.send_eapi_request() - else: - logger.warning("Exiting RequestManager context with no pending commands") - async def wait_for_commands(self, total_tasks: int) -> None: """Wait until all commands from the coroutine tests are received.""" async with self.condition: From eaa6c7572c36603be0b33b2d0ded4697e48329c7 Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Wed, 3 Jul 2024 08:46:35 -0400 Subject: [PATCH 09/12] Working state with asyncio.Queue --- anta/device.py | 108 --------------------------------------------- anta/models.py | 116 ++++++++++++++++++++++++++++++++++++++++++------- anta/runner.py | 37 ++++++++-------- 3 files changed, 119 insertions(+), 142 deletions(-) diff --git a/anta/device.py b/anta/device.py index 1d9e288f6..2564c7158 100644 --- a/anta/device.py +++ b/anta/device.py @@ -10,7 +10,6 @@ from abc import ABC, abstractmethod from collections import defaultdict from typing import TYPE_CHECKING, Any, Literal -from uuid import uuid4 import asyncssh import httpcore @@ -486,110 +485,3 @@ async def copy(self, sources: list[Path], destination: Path, direction: Literal[ return await asyncssh.scp(src, dst) - -class RequestManager: - """Request Manager class to handle sending requests to a device. - - # FIXME: Handle text output format - # FIXME: Handle different batch sizes for different tests - # FIXME: Handle the case where a single test send more than one batch - # FIXME: Cleanup multiple attributes - """ - - def __init__(self, device: AntaDevice, batch_size: int) -> None: - """ - Initialize the RequestManager object. - - Arguments: - ---------- - device: The device object to send the requests to. - batch_size: The maximum number of commands to send in a single request. - """ - self.device = device - self.batch_size = batch_size - self.condition = asyncio.Condition() - self.completed_coroutines = 0 - self.current_batch_commands: list[AntaCommand] = [] - self.current_batch_request_ids: set[str] = set() - self.current_batch_size = 0 - self.lock = asyncio.Lock() - self.pending_requests: dict[str, asyncio.Event] = {} - self.final_commands: dict[str, list[AntaCommand]] = {} - self.final_request_ids: dict[str, set[str]] = {} - - async def wait_for_commands(self, total_tasks: int) -> None: - """Wait until all commands from the coroutine tests are received.""" - async with self.condition: - logger.debug("<%s>: Waiting for all tests to submit their commands to the Request Manager", self.device.name) - await self.condition.wait_for(lambda: self.completed_coroutines == total_tasks) - logger.debug("<%s>: All tests have submitted their commands to the Request Manager", self.device.name) - - async def add_commands(self, commands: list[AntaCommand]) -> None: - """Add the commands to the current batch.""" - async with self.lock: - self.current_batch_commands.extend(commands) - self.current_batch_size += len(commands) - - request_id = self.generate_request_id() - self.pending_requests[request_id] = asyncio.Event() - self.current_batch_request_ids.add(request_id) - - # Once the batch size is reached, add it to the batches list - if self.current_batch_size >= self.batch_size: - await self.add_batch() - - return request_id - - async def add_batch(self) -> None: - """Add the current batch to the batches list.""" - batch_id = self.generate_request_id() - self.final_commands[batch_id] = self.current_batch_commands.copy() - self.final_request_ids[batch_id] = self.current_batch_request_ids.copy() - - # Reset the current batch and its attributes - self.current_batch_commands.clear() - self.current_batch_request_ids.clear() - self.current_batch_size = 0 - - async def send_eapi_requests(self) -> None: - """Send all the requests from the batches mapping.""" - # Check if there are any commands left in the current batch - if self.current_batch_commands: - await self.add_batch() - - tasks = [] - for batch_id, commands in self.final_commands.items(): - eapi_request_id = self.generate_request_id() - - task = asyncio.create_task( - self.device.collect_commands(commands, req_format="json", req_id=eapi_request_id), - name=f"Request ID {eapi_request_id} on {self.device.name}", - ) - task.add_done_callback(lambda t, i=batch_id: self.on_request_complete(t, self.final_request_ids[i])) - tasks.append(task) - await asyncio.gather(*tasks) - - def on_request_complete(self, task: asyncio.Task, request_ids: set[str]) -> None: - """Set the event when the request is complete.""" - task_name = task.get_name() - try: - if task.cancelled(): - logger.warning("%s was cancelled", task_name) - elif task.exception(): - logger.error("%s failed: %s", task_name, task.exception()) - else: - logger.debug("%s succeeded", task_name) - except asyncio.CancelledError: - logger.warning("%s was cancelled unexpectedly", task_name) - - for request_id in request_ids: - self.pending_requests[request_id].set() - del self.pending_requests[request_id] - - def generate_request_id(self) -> str: - """Generate a unique request ID using a UUID.""" - return str(uuid4()) - - async def wait_for_request(self, request_id: str) -> None: - """Wait for a specific request to complete.""" - await self.pending_requests[request_id].wait() diff --git a/anta/models.py b/anta/models.py index 34f7f5786..a50c2e168 100644 --- a/anta/models.py +++ b/anta/models.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio import hashlib import logging import re @@ -27,7 +28,7 @@ from rich.progress import Progress, TaskID - from anta.device import AntaDevice, RequestManager + from anta.device import AntaDevice F = TypeVar("F", bound=Callable[..., Any]) # Proper way to type input class - revisit this later if we get any issue @gmuloc @@ -389,7 +390,7 @@ class Filters(BaseModel): def __init__( self, device: AntaDevice, - request_manager: RequestManager, + manager: AntaTestManager, inputs: dict[str, Any] | AntaTest.Input | None = None, eos_data: list[dict[Any, Any] | str] | None = None, ) -> None: @@ -404,7 +405,7 @@ def __init__( """ self.logger: logging.Logger = logging.getLogger(f"{self.module}.{self.__class__.__name__}") self.device: AntaDevice = device - self.request_manager: RequestManager = request_manager + self.manager: AntaTestManager = manager self.inputs: AntaTest.Input self.instance_commands: list[AntaCommand] = [] self.result: TestResult = TestResult( @@ -538,11 +539,11 @@ def blocked(self) -> bool: state = True return state - async def send_commands(self) -> str: + async def send_commands(self) -> asyncio.Condition: """Collect outputs of all commands of this test class from the device of this test instance.""" try: if self.blocked is False: - return await self.request_manager.add_commands(self.instance_commands) + return await self.manager.put_commands(self.instance_commands) except Exception as e: # pylint: disable=broad-exception-caught # device._collect() is user-defined code. # We need to catch everything if we want the AntaTest object @@ -593,17 +594,13 @@ async def wrapper( # If the commands have not been collected, send them to the request manager and wait for the results if not self.collected: - logger.debug("<%s>: Sending commands for test %s to the Result Manager", self.device.name, self.name) - request_id = await self.send_commands() + logger.debug("<%s>: Sending commands for test %s to the Test Manager", self.device.name, self.name) + condition = await self.send_commands() - # Signal that this test coroutine has completed sending commands - async with self.request_manager.condition: - self.request_manager.completed_coroutines += 1 - self.request_manager.condition.notify() - - # Wait for the request containing the commands to complete - await self.request_manager.wait_for_request(request_id) - logger.debug("<%s>: All commands have been collected for test %s", self.device.name, self.name) + # Wait until all commands have been collected + async with condition: + await condition.wait_for(lambda: self.collected or self.result.result != "unset") + logger.debug("<%s>: Condition has been met for test %s", self.device.name, self.name) if self.result.result != "unset": AntaTest.update_progress() @@ -662,3 +659,92 @@ def test(self) -> None: ``` """ + +class AntaTestManager: + """TODO: Add docstring. + + # FIXME: Handle text output format + # FIXME: Handle different batch sizes for different tests + """ + + def __init__(self, device: AntaDevice, batch_size: int) -> None: + """TODO: Add docstring.""" + self.device = device + self.batch_size = batch_size + self.command_queue = asyncio.Queue() + self.notif_queue = asyncio.Queue() + self.conditions: dict[str, asyncio.Condition] = {} + self.eapi_requests: set[asyncio.Task] = set() + self.current_batch_commands: list[AntaCommand] = [] + self.current_batch_id = 1 + + async def put_commands(self, commands: list[AntaCommand]) -> asyncio.Condition: + """Put commands to the command queue.""" + # TODO: Since multiple tests (coroutines) can put commands, we might need to lock this + logger.debug("Putting %d commands to the command queue", len(commands)) + await self.command_queue.put(commands) + condition = await self.notif_queue.get() + logger.debug("Condition received from the notification queue: %s", condition) + return condition + + async def get_commands(self) -> None: + """Get commands from the command queue.""" + logger.debug("Commands consumer started") + while True: + try: + get_await = self.command_queue.get() + # Wait for all tests to submit their commands + commands = await asyncio.wait_for(get_await, timeout=2.0) + logger.debug("%d commands retrieved from the queue: %s", len(commands), commands) + condition = await self.parse_commands(commands) + # TODO: Put more info (context) in the condition + await self.notif_queue.put(condition) + except asyncio.TimeoutError: # noqa: PERF203 + logger.warning("Timeout expired. Tests are done submitting commands.") + # Send the last batch + if self.current_batch_commands: + logger.debug("Sending the last batch of commands") + await self.send_eapi_request(self.current_batch_id, self.current_batch_commands.copy()) + break + except Exception: + logger.exception("An error occurred while retrieving commands from the queue.") + + async def parse_commands(self, commands: list[AntaCommand]) -> asyncio.Condition: + """Parse the commands.""" + if self.current_batch_id not in self.conditions: + self.conditions[self.current_batch_id] = asyncio.Condition() + + condition = self.conditions[self.current_batch_id] + self.current_batch_commands.extend(commands) + + # Once the batch size is reached, schedule the request + if len(self.current_batch_commands) >= self.batch_size: + logger.debug("Creating a new request task with batch ID %s", self.current_batch_id) + task = asyncio.create_task(self.send_eapi_request(self.current_batch_id, self.current_batch_commands.copy())) + self.eapi_requests.add(task) + task.add_done_callback(self.eapi_requests.discard) + + # Increment the batch ID and reset commands for the next batch + self.current_batch_id += 1 + self.current_batch_commands.clear() + + return condition + + async def send_eapi_request(self, batch_id: int, commands: list[AntaCommand]) -> None: + """Send all the requests from the batches mapping.""" + eapi_request_id = f"Batch #{batch_id}" + + logger.debug("Sending eAPI requests for batch %s with commands: %s", batch_id, commands) + task = asyncio.create_task( + self.device.collect_commands(commands, req_format="json", req_id=eapi_request_id), + name=f"{eapi_request_id} on {self.device.name}", + ) + task.add_done_callback(lambda _t: asyncio.create_task(self.on_request_complete(batch_id))) + + async def on_request_complete(self, batch_id: int) -> None: + """TODO: Add docstring.""" + # Notify the tests that the request is complete + condition: asyncio.Condition = self.conditions[batch_id] + async with condition: + logger.debug("Notifying tests that the batch %s is complete. Condition: %s", batch_id, condition) + condition.notify_all() diff --git a/anta/runner.py b/anta/runner.py index 43bfd9190..3bbe8e440 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -14,9 +14,8 @@ from typing import TYPE_CHECKING from anta import GITHUB_SUGGESTION -from anta.device import RequestManager from anta.logger import anta_log_exception, exc_to_str -from anta.models import AntaTest +from anta.models import AntaTest, AntaTestManager from anta.tools import Catchtime, cprofile if TYPE_CHECKING: @@ -158,36 +157,36 @@ def prepare_tests( async def run_device_tests(device: AntaDevice, test_definitions: set[AntaTestDefinition], batch_size: int) -> list[TestResult]: """Run tests for a specific device using the RequestManager.""" - tasks = [] - request_manager = RequestManager(device=device, batch_size=batch_size) + manager = AntaTestManager(device=device, batch_size=batch_size) + background_tasks = set() + coros = [] for test in test_definitions: - full_name = f"{test.test.module}.{test.test.__name__}" try: - test_instance = test.test(device=device, request_manager=request_manager, inputs=test.inputs) - task = asyncio.create_task(test_instance.test(), name=full_name) - logger.debug("Creating task: %s", task.get_name()) - tasks.append(task) + test_instance = test.test(device=device, manager=manager, inputs=test.inputs) + coros.append(test_instance.test()) except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught # An AntaTest instance is potentially user-defined code. # We need to catch everything and exit gracefully with an error message. message = "\n".join( [ - f"There is an error when creating test {full_name}.", + f"There is an error when creating test {test.test.module}.{test.test.__name__}.", f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", ], ) anta_log_exception(e, message, logger) - # Wait until all commands from all tests are sent to the RequestManager - await request_manager.wait_for_commands(total_tasks=len(tasks)) + # Start the command consumer + consumer_task = asyncio.create_task(manager.get_commands()) + background_tasks.add(consumer_task) + consumer_task.add_done_callback(background_tasks.discard) - # Tell the RequestManager to send all requests to the device - logger.debug("<%s>: Sending all eAPI requests to the device", device.name) - await request_manager.send_eapi_requests() + # Launch all the tests and return the results + results = await asyncio.gather(*coros) + + logger.debug("All results for %s have been collected", device.name) + + return results - # Return the results of the tests - logger.debug("<%s>: All tests completed", device.name) - return [task.result() for task in tasks] @cprofile() async def main( # noqa: PLR0913 @@ -236,7 +235,7 @@ async def main( # noqa: PLR0913 if selected_tests is None: return - device_coroutines = [run_device_tests(device, test_definitions, batch_size=250) for device, test_definitions in selected_tests.items()] + device_coroutines = [run_device_tests(device, test_definitions, batch_size=100) for device, test_definitions in selected_tests.items()] run_info = ( "--- ANTA NRFU Run Information ---\n" From a65084646d828c4753f3e6be6821c81752cf5c1e Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Wed, 3 Jul 2024 08:57:19 -0400 Subject: [PATCH 10/12] Clean-up --- anta/catalog.py | 1 + anta/device.py | 1 + anta/models.py | 17 +++++++++-------- anta/runner.py | 2 +- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/anta/catalog.py b/anta/catalog.py index 7782d9f83..3e4254d4d 100644 --- a/anta/catalog.py +++ b/anta/catalog.py @@ -373,6 +373,7 @@ def from_list(data: ListAntaTestTuples) -> AntaCatalog: raise return AntaCatalog(tests) + # TODO: Move this change to a separate PR @staticmethod def merge(catalogs: list[AntaCatalog]) -> AntaCatalog: """Merge multiple AntaCatalog instances. diff --git a/anta/device.py b/anta/device.py index 2564c7158..c2385ff09 100644 --- a/anta/device.py +++ b/anta/device.py @@ -266,6 +266,7 @@ def __init__( raise ValueError(message) self.enable = enable self._enable_password = enable_password + # TODO: Move the max_connections setting change to a separate PR self._session: Device = Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout, limits=Limits(max_connections=7)) ssh_params: dict[str, Any] = {} if insecure: diff --git a/anta/models.py b/anta/models.py index a50c2e168..956260ad5 100644 --- a/anta/models.py +++ b/anta/models.py @@ -592,12 +592,12 @@ async def wrapper( self.save_commands_data(eos_data) self.logger.debug("Test %s initialized with input data %s", self.name, eos_data) - # If the commands have not been collected, send them to the request manager and wait for the results + # If the commands have not been collected, send them to the test manager if not self.collected: - logger.debug("<%s>: Sending commands for test %s to the Test Manager", self.device.name, self.name) + logger.debug("<%s>: Sending commands for test %s to the test manager", self.device.name, self.name) condition = await self.send_commands() - # Wait until all commands have been collected + # Grab the condition returned from the manager and wait until all commands have been collected async with condition: await condition.wait_for(lambda: self.collected or self.result.result != "unset") logger.debug("<%s>: Condition has been met for test %s", self.device.name, self.name) @@ -697,11 +697,12 @@ async def get_commands(self) -> None: commands = await asyncio.wait_for(get_await, timeout=2.0) logger.debug("%d commands retrieved from the queue: %s", len(commands), commands) condition = await self.parse_commands(commands) - # TODO: Put more info (context) in the condition + # TODO: Put more info (context) in the condition for logging await self.notif_queue.put(condition) except asyncio.TimeoutError: # noqa: PERF203 - logger.warning("Timeout expired. Tests are done submitting commands.") - # Send the last batch + logger.debug("Timeout expired. Tests are done submitting commands.") + + # Send the last batch if there are any commands left if self.current_batch_commands: logger.debug("Sending the last batch of commands") await self.send_eapi_request(self.current_batch_id, self.current_batch_commands.copy()) @@ -731,7 +732,7 @@ async def parse_commands(self, commands: list[AntaCommand]) -> asyncio.Condition return condition async def send_eapi_request(self, batch_id: int, commands: list[AntaCommand]) -> None: - """Send all the requests from the batches mapping.""" + """Send an eAPI request.""" eapi_request_id = f"Batch #{batch_id}" logger.debug("Sending eAPI requests for batch %s with commands: %s", batch_id, commands) @@ -743,7 +744,7 @@ async def send_eapi_request(self, batch_id: int, commands: list[AntaCommand]) -> async def on_request_complete(self, batch_id: int) -> None: """TODO: Add docstring.""" - # Notify the tests that the request is complete + # Notify the tests that the request is complete. Multiple tests can be waiting on the same batch (condition) condition: asyncio.Condition = self.conditions[batch_id] async with condition: logger.debug("Notifying tests that the batch %s is complete. Condition: %s", batch_id, condition) diff --git a/anta/runner.py b/anta/runner.py index 3bbe8e440..6a5426e6e 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -156,7 +156,7 @@ def prepare_tests( return device_to_tests async def run_device_tests(device: AntaDevice, test_definitions: set[AntaTestDefinition], batch_size: int) -> list[TestResult]: - """Run tests for a specific device using the RequestManager.""" + """Run tests for a specific device using the AntaTestManager.""" manager = AntaTestManager(device=device, batch_size=batch_size) background_tasks = set() coros = [] From 0ba16e5addb17e6c64e88281ef6bbcd32333a6d0 Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Wed, 14 Aug 2024 09:29:54 -0400 Subject: [PATCH 11/12] WIP asyncio.Event --- anta/catalog.py | 14 ++- anta/models.py | 229 ++++++++++++++++++++++-------------------------- anta/runner.py | 49 +++-------- 3 files changed, 119 insertions(+), 173 deletions(-) diff --git a/anta/catalog.py b/anta/catalog.py index a46269462..30bd34066 100644 --- a/anta/catalog.py +++ b/anta/catalog.py @@ -10,7 +10,6 @@ import math from collections import defaultdict from inspect import isclass -from itertools import chain from json import load as json_load from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Optional, Union @@ -387,21 +386,18 @@ def from_list(data: ListAntaTestTuples) -> AntaCatalog: raise return AntaCatalog(tests) - # TODO: Move this change to a separate PR - @staticmethod - def merge(catalogs: list[AntaCatalog]) -> AntaCatalog: - """Merge multiple AntaCatalog instances. + def merge(self, catalog: AntaCatalog) -> AntaCatalog: + """Merge two AntaCatalog instances. Parameters ---------- - catalogs: List of AntaCatalog instances to merge. + catalog: AntaCatalog instance to merge to this instance. Returns ------- - A new AntaCatalog instance containing the tests of all the instances. + A new AntaCatalog instance containing the tests of the two instances. """ - combined_tests = list(chain(*(catalog.tests for catalog in catalogs))) - return AntaCatalog(tests=combined_tests) + return AntaCatalog(tests=self.tests + catalog.tests) def dump(self) -> AntaCatalogFile: """Return an AntaCatalogFile instance from this AntaCatalog instance. diff --git a/anta/models.py b/anta/models.py index 1eb3d6fa2..131c3cea5 100644 --- a/anta/models.py +++ b/anta/models.py @@ -10,8 +10,8 @@ import logging import re from abc import ABC, abstractmethod -from enum import Enum -from functools import wraps +from collections import defaultdict +from functools import cached_property, wraps from string import Formatter from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar @@ -23,11 +23,11 @@ from anta.result_manager.models import TestResult if TYPE_CHECKING: - import asyncio from collections.abc import Coroutine from rich.progress import Progress, TaskID + from anta.catalog import AntaTestDefinition from anta.device import AntaDevice F = TypeVar("F", bound=Callable[..., Any]) @@ -128,18 +128,6 @@ def render(self, **params: str | int | bool) -> AntaCommand: ) -class CommandWeight(Enum): - """Enum to define the weight of a command. - - The weight of a command is used to specify the computational resources - and time required to execute the command on EOS. - """ - - LIGHT = "light" - MEDIUM = "medium" - HEAVY = "heavy" - - class AntaCommand(BaseModel): """Class to define a command. @@ -179,14 +167,22 @@ class AntaCommand(BaseModel): errors: list[str] = [] params: AntaParamsBaseModel = AntaParamsBaseModel() use_cache: bool = True - weight: CommandWeight = CommandWeight.LIGHT - @property + # def __hash__(self) -> int: + # """Implement hashing based on the `uid` property.""" + # return hash(self.uid) + + # def __eq__(self, other: object) -> bool: + # """Implement equality based on the `uid` property.""" + # if not isinstance(other, AntaCommand): + # return False + # return self.uid == other.uid + + @cached_property def uid(self) -> str: """Generate a unique identifier for this command.""" uid_str = f"{self.command}_{self.version}_{self.revision or 'NA'}_{self.ofmt}" - # Ignoring S324 probable use of insecure hash function - sha1 is enough for our needs. - return hashlib.sha1(uid_str.encode()).hexdigest() # noqa: S324 + return hashlib.sha256(uid_str.encode()).hexdigest() @property def json_output(self) -> dict[str, Any]: @@ -318,6 +314,8 @@ def test(self) -> None: instance_commands: List of AntaCommand instances of this test result: TestResult instance representing the result of this test logger: Python logger for this test instance + event: asyncio.Event used by the AntaTestManager to signal the test that all commands have been collected + so that it can start running the validation """ # Mandatory class attributes @@ -392,7 +390,6 @@ class Filters(BaseModel): def __init__( self, device: AntaDevice, - manager: AntaTestManager, inputs: dict[str, Any] | AntaTest.Input | None = None, eos_data: list[dict[Any, Any] | str] | None = None, ) -> None: @@ -407,7 +404,6 @@ def __init__( """ self.logger: logging.Logger = logging.getLogger(f"{self.module}.{self.__class__.__name__}") self.device: AntaDevice = device - self.manager: AntaTestManager = manager self.inputs: AntaTest.Input self.instance_commands: list[AntaCommand] = [] self.result: TestResult = TestResult( @@ -416,6 +412,7 @@ def __init__( categories=self.categories, description=self.description, ) + self.event: asyncio.Event | None = None self._init_inputs(inputs) if self.result.result == "unset": self._init_commands(eos_data) @@ -541,19 +538,6 @@ def blocked(self) -> bool: state = True return state - async def send_commands(self) -> asyncio.Condition: - """Collect outputs of all commands of this test class from the device of this test instance.""" - try: - if self.blocked is False: - return await self.manager.put_commands(self.instance_commands) - except Exception as e: # pylint: disable=broad-exception-caught - # device._collect() is user-defined code. - # We need to catch everything if we want the AntaTest object - # to live until the reporting - message = f"Exception raised while collecting commands for test {self.name} (on device {self.device.name})" - anta_log_exception(e, message, self.logger) - self.result.is_error(message=exc_to_str(e)) - @staticmethod def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]: """Decorate the `test()` method in child classes. @@ -570,7 +554,6 @@ def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]: async def wrapper( self: AntaTest, eos_data: list[dict[Any, Any] | str] | None = None, - **kwargs: dict[str, Any], ) -> TestResult: """Inner function for the anta_test decorator. @@ -579,7 +562,6 @@ async def wrapper( self: The test instance. eos_data: Populate outputs of the test commands instead of collecting from devices. This list must have the same length and order than the `instance_commands` instance attribute. - kwargs: Any keyword argument to pass to the test. Returns ------- @@ -594,16 +576,12 @@ async def wrapper( self.save_commands_data(eos_data) self.logger.debug("Test %s initialized with input data %s", self.name, eos_data) - # If the commands have not been collected, send them to the test manager - if not self.collected: - logger.debug("<%s>: Sending commands for test %s to the test manager", self.device.name, self.name) - condition = await self.send_commands() - - # Grab the condition returned from the manager and wait until all commands have been collected - async with condition: - await condition.wait_for(lambda: self.collected or self.result.result != "unset") - logger.debug("<%s>: Condition has been met for test %s", self.device.name, self.name) + # Wait until all commands have been collected by the manager before running the test + logger.debug("Waiting for all commands to be collected for %s on device %s", self.name, self.device.name) + await self.event.wait() + self.logger.debug("Starting validation for %s on device %s", self.name, self.device.name) + if not self.collected: if self.result.result != "unset": AntaTest.update_progress() return self.result @@ -619,8 +597,9 @@ async def wrapper( AntaTest.update_progress() return self.result + # Run the test in a separate thread to avoid blocking the event loop try: - function(self, **kwargs) + await asyncio.to_thread(function, self) except Exception as e: # pylint: disable=broad-exception-caught # test() is user-defined code. # We need to catch everything if we want the AntaTest object @@ -631,6 +610,8 @@ async def wrapper( # TODO: find a correct way to time test execution AntaTest.update_progress() + + logger.debug("Validation completed for %s on device %s", self.name, self.device.name) return self.result return wrapper @@ -642,7 +623,7 @@ def update_progress(cls: type[AntaTest]) -> None: cls.progress.update(cls.nrfu_task, advance=1) @abstractmethod - def test(self) -> Coroutine[Any, Any, TestResult]: + def test(self) -> None: """Core of the test logic. This is an abstractmethod that must be implemented by child classes. @@ -666,89 +647,87 @@ def test(self) -> None: class AntaTestManager: """TODO: Add docstring. - # FIXME: Handle text output format # FIXME: Handle different batch sizes for different tests + # FIXME: Handle decorators that skip tests. For now commands are still sent to the device. """ def __init__(self, device: AntaDevice, batch_size: int) -> None: """TODO: Add docstring.""" self.device = device self.batch_size = batch_size - self.command_queue = asyncio.Queue() - self.notif_queue = asyncio.Queue() - self.conditions: dict[str, asyncio.Condition] = {} - self.eapi_requests: set[asyncio.Task] = set() - self.current_batch_commands: list[AntaCommand] = [] - self.current_batch_id = 1 - - async def put_commands(self, commands: list[AntaCommand]) -> asyncio.Condition: - """Put commands to the command queue.""" - # TODO: Since multiple tests (coroutines) can put commands, we might need to lock this - logger.debug("Putting %d commands to the command queue", len(commands)) - await self.command_queue.put(commands) - condition = await self.notif_queue.get() - logger.debug("Condition received from the notification queue: %s", condition) - return condition - - async def get_commands(self) -> None: - """Get commands from the command queue.""" - logger.debug("Commands consumer started") - while True: - try: - get_await = self.command_queue.get() - # Wait for all tests to submit their commands - commands = await asyncio.wait_for(get_await, timeout=2.0) - logger.debug("%d commands retrieved from the queue: %s", len(commands), commands) - condition = await self.parse_commands(commands) - # TODO: Put more info (context) in the condition for logging - await self.notif_queue.put(condition) - except asyncio.TimeoutError: # noqa: PERF203 - logger.debug("Timeout expired. Tests are done submitting commands.") - - # Send the last batch if there are any commands left - if self.current_batch_commands: - logger.debug("Sending the last batch of commands") - await self.send_eapi_request(self.current_batch_id, self.current_batch_commands.copy()) - break - except Exception: - logger.exception("An error occurred while retrieving commands from the queue.") - - async def parse_commands(self, commands: list[AntaCommand]) -> asyncio.Condition: - """Parse the commands.""" - if self.current_batch_id not in self.conditions: - self.conditions[self.current_batch_id] = asyncio.Condition() - - condition = self.conditions[self.current_batch_id] - self.current_batch_commands.extend(commands) - - # Once the batch size is reached, schedule the request - if len(self.current_batch_commands) >= self.batch_size: - logger.debug("Creating a new request task with batch ID %s", self.current_batch_id) - task = asyncio.create_task(self.send_eapi_request(self.current_batch_id, self.current_batch_commands.copy())) - self.eapi_requests.add(task) - task.add_done_callback(self.eapi_requests.discard) - - # Increment the batch ID and reset commands for the next batch - self.current_batch_id += 1 - self.current_batch_commands.clear() - - return condition - - async def send_eapi_request(self, batch_id: int, commands: list[AntaCommand]) -> None: - """Send an eAPI request.""" - eapi_request_id = f"Batch #{batch_id}" - - logger.debug("Sending eAPI requests for batch %s with commands: %s", batch_id, commands) - task = asyncio.create_task( - self.device.collect_commands(commands, req_format="json", req_id=eapi_request_id), - name=f"{eapi_request_id} on {self.device.name}", - ) - task.add_done_callback(lambda _t: asyncio.create_task(self.on_request_complete(batch_id))) + self.completed_command_ids: set[int] = set() + self.test_map: defaultdict[AntaTest, set[int]] = defaultdict(set) + self.events: dict[AntaTest, asyncio.Event] = {} - async def on_request_complete(self, batch_id: int) -> None: - """TODO: Add docstring.""" - # Notify the tests that the request is complete. Multiple tests can be waiting on the same batch (condition) - condition: asyncio.Condition = self.conditions[batch_id] - async with condition: - logger.debug("Notifying tests that the batch %s is complete. Condition: %s", batch_id, condition) - condition.notify_all() + async def run_tests(self, test_definitions: set[AntaTestDefinition]) -> None: + json_commands: list[AntaCommand] = [] + text_commands: list[AntaCommand] = [] + test_tasks = [] + batch_tasks = [] + + for test_definition in test_definitions: + try: + test_instance = test_definition.test(device=self.device, inputs=test_definition.inputs) + # Skip the test if it has blocked commands + if test_instance.blocked is True: + continue + + test_instance.event = asyncio.Event() + self.events[test_instance] = test_instance.event + + for command in test_instance.instance_commands: + self.test_map[test_instance].add(id(command)) + if command.ofmt == "json": + json_commands.append(command) + elif command.ofmt == "text": + text_commands.append(command) + + test_tasks.append(asyncio.create_task(test_instance.test())) + except Exception as exc: # pylint: disable=broad-exception-caught + # An AntaTest instance is potentially user-defined code. + # We need to catch everything and exit gracefully with an error message. + message = "\n".join( + [ + f"There is an error when creating test {test_definition.test.module}.{test_definition.test.__name__}.", + f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", + ], + ) + anta_log_exception(exc, message, logger) + + total_commands = len(json_commands) + len(text_commands) + logger.debug("Total commands to process for %s: %d", self.device.name, total_commands) + + # Create batches for JSON commands + json_batches = [json_commands[i : i + self.batch_size] for i in range(0, len(json_commands), self.batch_size)] + # Create batches for text commands + text_batches = [text_commands[i : i + self.batch_size] for i in range(0, len(text_commands), self.batch_size)] + + logger.debug("Number of JSON batches for %s: %d", self.device.name, len(json_batches)) + logger.debug("Number of text batches for %s: %d", self.device.name, len(text_batches)) + + # Process JSON batches + for i, batch in enumerate(json_batches): + batch_tasks.append(asyncio.create_task(self.process_batch(i, batch, "json"))) + + # Process text batches + for i, batch in enumerate(text_batches): + batch_tasks.append(asyncio.create_task(self.process_batch(i, batch, "text"))) + + # Make sure all batch tasks are completed + await asyncio.gather(*batch_tasks) + + # Make sure all test tasks are completed and return the results + return await asyncio.gather(*test_tasks) + + async def process_batch(self, batch_id: int, batch: list[AntaCommand], batch_format: Literal["text", "json"] = "json") -> None: + await self.device.collect_commands(batch, req_format=batch_format, req_id=f"Batch #{batch_id}") + for command in batch: + cmd_id = id(command) + self.completed_command_ids.add(cmd_id) + self.check_test_completion(cmd_id) + + def check_test_completion(self, cmd_id: int) -> None: + for test_instance, cmd_ids in self.test_map.items(): + if cmd_id in cmd_ids and cmd_ids.issubset(self.completed_command_ids): + logger.debug("All commands completed for %s", test_instance.name) + self.events[test_instance].set() diff --git a/anta/runner.py b/anta/runner.py index 4d3c7091c..d0af7463c 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -13,17 +13,17 @@ from itertools import chain from typing import TYPE_CHECKING -from anta import GITHUB_SUGGESTION -from anta.logger import anta_log_exception, exc_to_str +from anta.logger import exc_to_str from anta.models import AntaTest, AntaTestManager from anta.tools import Catchtime, cprofile if TYPE_CHECKING: + from collections.abc import Coroutine + from anta.catalog import AntaCatalog, AntaTestDefinition from anta.device import AntaDevice from anta.inventory import AntaInventory from anta.result_manager import ResultManager - from anta.result_manager.models import TestResult logger = logging.getLogger(__name__) @@ -156,39 +156,6 @@ def prepare_tests( return device_to_tests -async def run_device_tests(device: AntaDevice, test_definitions: set[AntaTestDefinition], batch_size: int) -> list[TestResult]: - """Run tests for a specific device using the AntaTestManager.""" - manager = AntaTestManager(device=device, batch_size=batch_size) - background_tasks = set() - coros = [] - for test in test_definitions: - try: - test_instance = test.test(device=device, manager=manager, inputs=test.inputs) - coros.append(test_instance.test()) - except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught - # An AntaTest instance is potentially user-defined code. - # We need to catch everything and exit gracefully with an error message. - message = "\n".join( - [ - f"There is an error when creating test {test.test.module}.{test.test.__name__}.", - f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", - ], - ) - anta_log_exception(e, message, logger) - - # Start the command consumer - consumer_task = asyncio.create_task(manager.get_commands()) - background_tasks.add(consumer_task) - consumer_task.add_done_callback(background_tasks.discard) - - # Launch all the tests and return the results - results = await asyncio.gather(*coros) - - logger.debug("All results for %s have been collected", device.name) - - return results - - @cprofile() async def main( # noqa: PLR0913 manager: ResultManager, @@ -236,7 +203,11 @@ async def main( # noqa: PLR0913 if selected_tests is None: return - device_coroutines = [run_device_tests(device, test_definitions, batch_size=100) for device, test_definitions in selected_tests.items()] + coros: list[Coroutine] = [] + + for device, test_definitions in selected_tests.items(): + test_manager = AntaTestManager(device=device, batch_size=100) + coros.append(test_manager.run_tests(test_definitions)) run_info = ( "--- ANTA NRFU Run Information ---\n" @@ -257,7 +228,7 @@ async def main( # noqa: PLR0913 if dry_run: logger.info("Dry-run mode, exiting before running the tests.") - for coro in device_coroutines: + for coro in coros: coro.close() return @@ -265,7 +236,7 @@ async def main( # noqa: PLR0913 AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=catalog.final_tests_count) with Catchtime(logger=logger, message="Running ANTA tests"): - results = chain.from_iterable(await asyncio.gather(*device_coroutines)) + results = chain.from_iterable(await asyncio.gather(*coros)) for result in results: manager.add(result) From 1771d72ab45a0db9ad46b6a4e878fcdc2558430f Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Wed, 14 Aug 2024 09:32:12 -0400 Subject: [PATCH 12/12] Remove commented code --- anta/models.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/anta/models.py b/anta/models.py index 131c3cea5..01857057b 100644 --- a/anta/models.py +++ b/anta/models.py @@ -168,16 +168,6 @@ class AntaCommand(BaseModel): params: AntaParamsBaseModel = AntaParamsBaseModel() use_cache: bool = True - # def __hash__(self) -> int: - # """Implement hashing based on the `uid` property.""" - # return hash(self.uid) - - # def __eq__(self, other: object) -> bool: - # """Implement equality based on the `uid` property.""" - # if not isinstance(other, AntaCommand): - # return False - # return self.uid == other.uid - @cached_property def uid(self) -> str: """Generate a unique identifier for this command."""