Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(anta): Refactor collect to send multiple commands per eAPI request #736

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
228 changes: 137 additions & 91 deletions anta/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from aiocache import Cache
from aiocache.plugins import HitMissRatioPlugin
from asyncssh import SSHClientConnection, SSHClientConnectionOptions
from httpx import ConnectError, HTTPError, TimeoutException
from httpx import ConnectError, HTTPError, Limits, TimeoutException

import asynceapi
from anta import __DEBUG__
from anta.logger import anta_log_exception, exc_to_str
from anta.models import AntaCommand
from asynceapi import Device, EapiCommandError

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down Expand Up @@ -117,7 +117,7 @@ def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
yield "disable_cache", self.cache is None

@abstractmethod
async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
async def _collect(self, anta_commands: list[AntaCommand], *, req_format: Literal["json", "text"] = "json", req_id: str | None = None) -> None:
"""Collect device command output.

This abstract coroutine can be used to implement any command collection method
Expand All @@ -136,46 +136,41 @@ async def _collect(self, command: AntaCommand, *, collection_id: str | None = No
collection_id: An identifier used to build the eAPI request ID.
"""

async def collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
"""Collect the output for a specified command.

When caching is activated on both the device and the command,
this method prioritizes retrieving the output from the cache. In cases where the output isn't cached yet,
it will be freshly collected and then stored in the cache for future access.
The method employs asynchronous locks based on the command's UID to guarantee exclusive access to the cache.

When caching is NOT enabled, either at the device or command level, the method directly collects the output
via the private `_collect` method without interacting with the cache.

Parameters
----------
command: The command to collect.
collection_id: An identifier used to build the eAPI request ID.
"""
# Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough
# https://github.com/pylint-dev/pylint/issues/7258
if self.cache is not None and self.cache_locks is not None and command.use_cache:
async with self.cache_locks[command.uid]:
cached_output = await self.cache.get(command.uid) # pylint: disable=no-member

if cached_output is not None:
logger.debug("Cache hit for %s on %s", command.command, self.name)
command.output = cached_output
else:
await self._collect(command=command, collection_id=collection_id)
await self.cache.set(command.uid, command.output) # pylint: disable=no-member
else:
await self._collect(command=command, collection_id=collection_id)

async def collect_commands(self, commands: list[AntaCommand], *, collection_id: str | None = None) -> None:
async def collect_commands(self, anta_commands: list[AntaCommand], *, req_format: Literal["text", "json"] = "json", req_id: str) -> None:
"""Collect multiple commands.

Parameters
----------
commands: The commands to collect.
collection_id: An identifier used to build the eAPI request ID.
"""
await asyncio.gather(*(self.collect(command=command, collection_id=collection_id) for command in commands))
# FIXME: Avoid querying the cache for the initial commands that are not cached.
commands_to_collect = []

# FIXME: Don't loop over commands if the cache is disabled
for command in anta_commands:
if self.cache is not None and self.cache_locks is not None and command.use_cache:
async with self.cache_locks[command.uid]:
# Need to disable pylint no-member as Cache is a proxy class and pylint is not smart enough
# https://github.com/pylint-dev/pylint/issues/7258
cached_output = await self.cache.get(command.uid) # pylint: disable=no-member

if cached_output is not None:
logger.debug("Cache hit for %s on %s", command.command, self.name)
command.output = cached_output
else:
commands_to_collect.append(command)
else:
commands_to_collect.append(command)

# Collect the batch of commands that are not cached
if commands_to_collect:
await self._collect(commands_to_collect, req_format=req_format, req_id=req_id)
# Cache the outputs of the collected commands
for command in commands_to_collect:
if self.cache is not None and self.cache_locks is not None and command.use_cache:
async with self.cache_locks[command.uid]:
await self.cache.set(command.uid, command.output) # pylint: disable=no-member

@abstractmethod
async def refresh(self) -> None:
Expand Down Expand Up @@ -271,7 +266,8 @@ def __init__(
raise ValueError(message)
self.enable = enable
self._enable_password = enable_password
self._session: asynceapi.Device = asynceapi.Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout)
# TODO: Move the max_connections setting change to a separate PR
self._session: Device = Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout, limits=Limits(max_connections=7))
ssh_params: dict[str, Any] = {}
if insecure:
ssh_params["known_hosts"] = None
Expand Down Expand Up @@ -306,7 +302,79 @@ def _keys(self) -> tuple[Any, ...]:
"""
return (self._session.host, self._session.port)

async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: # noqa: C901 function is too complex - because of many required except blocks #pylint: disable=line-too-long
async def _handle_eapi_command_error(self, exception: EapiCommandError, anta_commands: list[AntaCommand], *, req_format: str, req_id: str) -> None:
"""Handle EapiCommandError exceptions."""
# Populate the output attribute of the AntaCommand objects with the commands that passed
passed_outputs = exception.passed[1:] if self.enable else exception.passed
for anta_command, output in zip(anta_commands, passed_outputs):
anta_command.output = output

# Populate the errors attribute of the AntaCommand object of the command that failed
err_at = exception.err_at - 1 if self.enable else exception.err_at
anta_command = anta_commands[err_at]
anta_command.errors = exception.errors
if anta_command.requires_privileges:
logger.error(
"Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.",
anta_command.command,
self.name,
)

if anta_command.supported:
error_message = exception.errors[0] if len(exception.errors) == 1 else exception.errors
logger.error(
"Command '%s' failed on %s: %s",
anta_command.command,
self.name,
error_message,
)
else:
logger.error("Command '%s' is not supported on %s (%s).", anta_command.command, self.name, self.hw_model)

# Collect the commands that were not executed
await self._collect(anta_commands=anta_commands[err_at + 1 :], req_format=req_format, req_id=req_id)

def _handle_timeout_exception(self, exception: TimeoutException, anta_commands: list[AntaCommand]) -> None:
"""Handle TimeoutException exceptions."""
# FIXME: Handle timeouts more gracefully
for anta_command in anta_commands:
anta_command.errors = [exc_to_str(exception)]

timeouts = self._session.timeout.as_dict()
logger.error(
"%s occurred while sending commands to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s",
exc_to_str(exception),
self.name,
timeouts["connect"],
timeouts["read"],
timeouts["write"],
timeouts["pool"],
)

def _handle_connect_os_error(self, exception: ConnectError | OSError, anta_commands: list[AntaCommand]) -> None:
"""Handle HTTPX ConnectError and OSError exceptions."""
# FIXME: Handle connection errors more gracefully
for anta_command in anta_commands:
anta_command.errors = [exc_to_str(exception)]

if (isinstance(exc := exception.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(
os_error := exception, OSError
):
if isinstance(os_error.__cause__, OSError):
os_error = os_error.__cause__
logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error)
else:
anta_log_exception(exception, f"An error occurred while issuing an eAPI request to {self.name}", logger)

def _handle_http_error(self, exception: HTTPError, anta_commands: list[AntaCommand]) -> None:
"""Handle HTTPError exceptions."""
# FIXME: Handle HTTP errors more gracefully
for anta_command in anta_commands:
anta_command.errors = [exc_to_str(exception)]

anta_log_exception(exception, f"An error occurred while issuing an eAPI request to {self.name}", logger)

async def _collect(self, anta_commands: list[AntaCommand], *, req_format: Literal["json", "text"] = "json", req_id: str) -> None:
"""Collect device command output from EOS using aio-eapi.

Supports outformat `json` and `text` as output structure.
Expand All @@ -318,65 +386,43 @@ async def _collect(self, command: AntaCommand, *, collection_id: str | None = No
command: The command to collect.
collection_id: An identifier used to build the eAPI request ID.
"""
commands: list[dict[str, str | int]] = []
commands = [
{"cmd": anta_command.command, "revision": anta_command.revision} if anta_command.revision else {"cmd": anta_command.command}
for anta_command in anta_commands
]

if self.enable and self._enable_password is not None:
commands.append(
{
"cmd": "enable",
"input": str(self._enable_password),
},
)
commands.insert(0, {"cmd": "enable", "input": str(self._enable_password)})
elif self.enable:
# No password
commands.append({"cmd": "enable"})
commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}]
commands.insert(0, {"cmd": "enable"})

try:
response: list[dict[str, Any] | str] = await self._session.cli(
response = await self._session.cli(
commands=commands,
ofmt=command.ofmt,
version=command.version,
req_id=f"ANTA-{collection_id}-{id(command)}" if collection_id else f"ANTA-{id(command)}",
) # type: ignore[assignment] # multiple commands returns a list
# Do not keep response of 'enable' command
command.output = response[-1]
except asynceapi.EapiCommandError as e:
ofmt=req_format,
req_id=f"ANTA-{req_id}",
)
# If enable was used, exclude the first element from the response
if self.enable:
response = response[1:]

# Populate the output attribute of the AntaCommand objects
for anta_command, command_output in zip(anta_commands, response):
anta_command.output = command_output

except EapiCommandError as e:
# This block catches exceptions related to EOS issuing an error.
command.errors = e.errors
if command.requires_privileges:
logger.error(
"Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.", command.command, self.name
)
if command.supported:
logger.error("Command '%s' failed on %s: %s", command.command, self.name, e.errors[0] if len(e.errors) == 1 else e.errors)
else:
logger.debug("Command '%s' is not supported on '%s' (%s)", command.command, self.name, self.hw_model)
await self._handle_eapi_command_error(e, anta_commands, req_format=req_format, req_id=req_id)
except TimeoutException as e:
# This block catches Timeout exceptions.
command.errors = [exc_to_str(e)]
timeouts = self._session.timeout.as_dict()
logger.error(
"%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s",
exc_to_str(e),
self.name,
timeouts["connect"],
timeouts["read"],
timeouts["write"],
timeouts["pool"],
)
except (ConnectError, OSError) as e:
# This block catches OSError and socket issues related exceptions.
command.errors = [exc_to_str(e)]
if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(os_error := e, OSError): # pylint: disable=no-member
if isinstance(os_error.__cause__, OSError):
os_error = os_error.__cause__
logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error)
else:
anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
# This block catches exceptions related to the timeout of the request.
self._handle_timeout_exception(e, anta_commands)
except ConnectError as e:
# This block catches exceptions related to the connection to the device.
self._handle_connect_os_error(e, anta_commands)
except HTTPError as e:
# This block catches most of the httpx Exceptions and logs a general message.
command.errors = [exc_to_str(e)]
anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
logger.debug("%s: %s", self.name, command)
# This block catches exceptions related to the HTTP connection.
self._handle_http_error(e, anta_commands)

async def refresh(self) -> None:
"""Update attributes of an AsyncEOSDevice instance.
Expand All @@ -389,8 +435,8 @@ async def refresh(self) -> None:
logger.debug("Refreshing device %s", self.name)
self.is_online = await self._session.check_connection()
if self.is_online:
show_version = AntaCommand(command="show version")
await self._collect(show_version)
show_version = AntaCommand(command="show version", revision=1)
await self._collect([show_version], req_format="json", req_id="Refresh")
if not show_version.collected:
logger.warning("Cannot get hardware information from device %s", self.name)
else:
Expand Down
Loading
Loading