Skip to content

Commit

Permalink
refactor: Remove dependency on netaddr (#261)
Browse files Browse the repository at this point in the history
* Refactor: Remove dependency on netaddr

* Refactor: Unmangling
  • Loading branch information
gmuloc authored Jul 11, 2023
1 parent 46dc361 commit d4c4053
Show file tree
Hide file tree
Showing 14 changed files with 219 additions and 164 deletions.
2 changes: 1 addition & 1 deletion anta/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
)
@click.option("--ignore-status", show_envvar=True, is_flag=True, default=False, help="Always exit with success")
@click.option("--ignore-error", show_envvar=True, is_flag=True, default=False, help="Only report failures and not errors")
def anta(ctx: click.Context, inventory: pathlib.Path, ignore_status: bool, ignore_error: bool, **kwargs: Dict[str, Any]) -> None:
def anta(ctx: click.Context, inventory: pathlib.Path, ignore_status: bool, ignore_error: bool, **kwargs: Any) -> None:
# pylint: disable=unused-argument
"""Arista Network Test Automation (ANTA) CLI"""
ctx.ensure_object(dict)
Expand Down
13 changes: 3 additions & 10 deletions anta/cli/exec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@

from aioeapi import EapiCommandError

from anta import __DEBUG__
from anta.device import AntaDevice
from anta.inventory import AntaInventory
from anta.models import AntaCommand
from anta.tools.misc import exc_to_str
from anta.tools.misc import anta_log_exception, exc_to_str

EOS_SCHEDULED_TECH_SUPPORT = "/mnt/flash/schedule/tech-support"

Expand Down Expand Up @@ -85,10 +84,7 @@ async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "tex
for r in res:
if isinstance(r, Exception):
message = "Error when collecting commands"
if __DEBUG__:
logger.exception(message, exc_info=r)
else:
logger.error(f"{message}: {exc_to_str(r)}")
anta_log_exception(r, message, logger)


async def collect_scheduled_show_tech(inv: AntaInventory, root_dir: Path, configure: bool, tags: Optional[List[str]] = None, latest: Optional[int] = None) -> None:
Expand Down Expand Up @@ -147,10 +143,7 @@ async def collect(device: AntaDevice) -> None:
# In this case we want to catch all exceptions
except Exception as e: # pylint: disable=broad-except
message = f"Unable to collect tech-support on device {device.name}"
if __DEBUG__:
logger.exception(message)
else:
logger.error(f"{message}: {exc_to_str(e)}")
anta_log_exception(e, message, logger)

logger.info("Connecting to devices...")
await inv.connect_inventory()
Expand Down
19 changes: 4 additions & 15 deletions anta/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from yaml import safe_load

import anta.loader
from anta import __DEBUG__
from anta.inventory import AntaInventory
from anta.result_manager.models import TestResult
from anta.tools.misc import exc_to_str
from anta.tools.misc import anta_log_exception

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,11 +58,7 @@ def parse_inventory(ctx: click.Context, path: Path) -> AntaInventory:
)
except Exception as e: # pylint: disable=broad-exception-caught
message = f"Unable to parse ANTA Inventory file '{path}'"
if __DEBUG__:
logger.exception(message)
else:
logger.error(message + f": {exc_to_str(e)}")

anta_log_exception(e, message, logger)
ctx.fail(message)
return inventory

Expand All @@ -90,10 +85,7 @@ def parse_catalog(ctx: click.Context, param: Option, value: str) -> List[Tuple[C
# pylint: disable-next=broad-exception-caught
except Exception as e:
message = f"Unable to parse ANTA Tests Catalog file '{value}'"
if __DEBUG__:
logger.exception(message)
else:
logger.error(message + f": {exc_to_str(e)}")
anta_log_exception(e, message, logger)
ctx.fail(message)

return anta.loader.parse_catalog(data)
Expand All @@ -108,10 +100,7 @@ def setup_logging(ctx: click.Context, param: Option, value: str) -> str:
anta.loader.setup_logging(value)
except Exception as e: # pylint: disable=broad-exception-caught
message = f"Unable to set ANTA logging level '{value}'"
if __DEBUG__:
logger.exception(message)
else:
logger.error(message + f": {exc_to_str(e)}")
anta_log_exception(e, message, logger)
ctx.fail(message)

return value
Expand Down
6 changes: 3 additions & 3 deletions anta/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
decorators for tests
"""
from functools import wraps
from typing import Any, Callable, Dict, List, TypeVar, cast
from typing import Any, Callable, List, TypeVar, cast

from anta.models import AntaCommand, AntaTest
from anta.result_manager.models import TestResult
Expand All @@ -28,7 +28,7 @@ def decorator(function: F) -> F:
"""

@wraps(function)
async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult:
async def wrapper(*args: Any, **kwargs: Any) -> TestResult:
"""
wrapper for func
"""
Expand Down Expand Up @@ -66,7 +66,7 @@ def decorator(function: F) -> F:
"""

@wraps(function)
async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult:
async def wrapper(*args: Any, **kwargs: Any) -> TestResult:
"""
wrapper for func
"""
Expand Down
15 changes: 7 additions & 8 deletions anta/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from anta import __DEBUG__
from anta.models import DEFAULT_TAG, AntaCommand
from anta.tools.misc import exc_to_str
from anta.tools.misc import anta_log_exception, exc_to_str

logger = logging.getLogger(__name__)

Expand All @@ -26,7 +26,7 @@
# Hic Sunt Draconis.
# Are we proud of this? No.
# Waiting for: https://github.com/jeremyschulman/aio-eapi/issues/9
def patched_jsoncrpc_command(self: Device, commands: List[str], ofmt: str, **kwargs: Dict[Any, Any]) -> Dict[str, Any]:
def patched_jsoncrpc_command(self: Device, commands: List[str], ofmt: str, **kwargs: Any) -> Dict[str, Any]:
"""
Used to create the JSON-RPC command dictionary object
"""
Expand Down Expand Up @@ -271,17 +271,16 @@ async def collect(self, command: AntaCommand) -> None:
logger.debug(f"{self.name}: {command}")

except EapiCommandError as e:
logger.error(f"Command '{command.command}' failed on {self.name}: {e.errmsg}")
message = f"Command '{command.command}' failed on {self.name}"
anta_log_exception(e, message, logger)
command.failed = e
except (HTTPError, ConnectError) as e:
logger.error(f"Cannot connect to device {self.name}: {exc_to_str(e)}")
message = f"Cannot connect to device {self.name}"
anta_log_exception(e, message, logger)
command.failed = e
except Exception as e: # pylint: disable=broad-exception-caught
message = f"Exception raised while collecting command '{command.command}' on device {self.name}"
if __DEBUG__:
logger.exception(message)
else:
logger.error(message + f": {exc_to_str(e)}")
anta_log_exception(e, message, logger)
command.failed = e
logger.debug(command)

Expand Down
105 changes: 81 additions & 24 deletions anta/inventory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@

import asyncio
import logging
from ipaddress import ip_address, ip_network
from typing import Any, Dict, List, Optional

from netaddr import IPAddress, IPNetwork
from pydantic import ValidationError
from yaml import safe_load

from anta import __DEBUG__
from anta.device import AntaDevice, AsyncEOSDevice
from anta.inventory.exceptions import InventoryIncorrectSchema, InventoryRootKeyError
from anta.inventory.models import AntaInventoryInput
from anta.tools.misc import exc_to_str
from anta.tools.misc import anta_log_exception

logger = logging.getLogger(__name__)

Expand All @@ -42,6 +41,81 @@ def __str__(self) -> str:
devs[dev_type] += 1
return f"ANTA Inventory contains {' '.join([f'{n} devices ({t})' for t, n in devs.items()])}"

@staticmethod
def _parse_hosts(inventory_input: AntaInventoryInput, inventory: AntaInventory, **kwargs: Any) -> None:
"""
Parses the host section of an AntaInventoryInput and add the devices to the inventory
Args:
inventory_input (AntaInventoryInput): AntaInventoryInput used to parse the devices
inventory (AntaInventory): AntaInventory to add the parsed devices to
"""
if inventory_input.hosts is None:
return

for host in inventory_input.hosts:
device = AsyncEOSDevice(name=host.name, host=str(host.host), port=host.port, tags=host.tags, **kwargs)
inventory.add_device(device)

@staticmethod
def _parse_networks(inventory_input: AntaInventoryInput, inventory: AntaInventory, **kwargs: Any) -> None:
"""
Parses the network section of an AntaInventoryInput and add the devices to the inventory.
Args:
inventory_input (AntaInventoryInput): AntaInventoryInput used to parse the devices
inventory (AntaInventory): AntaInventory to add the parsed devices to
Raises:
InventoryIncorrectSchema: Inventory file is not following AntaInventory Schema.
"""
if inventory_input.networks is None:
return

for network in inventory_input.networks:
try:
for host_ip in ip_network(str(network.network)):
device = AsyncEOSDevice(host=str(host_ip), tags=network.tags, **kwargs)
inventory.add_device(device)
except ValueError as e:
message = "Could not parse network {network.network} in the inventory"
anta_log_exception(e, message, logger)
raise InventoryIncorrectSchema(message) from e

@staticmethod
def _parse_ranges(inventory_input: AntaInventoryInput, inventory: AntaInventory, **kwargs: Any) -> None:
"""
Parses the range section of an AntaInventoryInput and add the devices to the inventory.
Args:
inventory_input (AntaInventoryInput): AntaInventoryInput used to parse the devices
inventory (AntaInventory): AntaInventory to add the parsed devices to
Raises:
InventoryIncorrectSchema: Inventory file is not following AntaInventory Schema.
"""
if inventory_input.ranges is None:
return

for range_def in inventory_input.ranges:
try:
range_increment = ip_address(str(range_def.start))
range_stop = ip_address(str(range_def.end))
while range_increment <= range_stop: # type: ignore[operator]
# mypy raise an issue about comparing IPv4Address and IPv6Address
# but this is handled by the ipaddress module natively by raising a TypeError
device = AsyncEOSDevice(host=str(range_increment), tags=range_def.tags, **kwargs)
inventory.add_device(device)
range_increment += 1
except ValueError as e:
message = f"Could not parse the following range in the inventory: {range_def.start} - {range_def.end}"
anta_log_exception(e, message, logger)
raise InventoryIncorrectSchema(message) from e
except TypeError as e:
message = f"A range in the inventory has different address families for start and end: {range_def.start} - {range_def.end}"
anta_log_exception(e, message, logger)
raise InventoryIncorrectSchema(message) from e

@staticmethod
def parse(
inventory_file: str, username: str, password: str, enable_password: Optional[str] = None, timeout: Optional[float] = None, insecure: bool = False
Expand Down Expand Up @@ -81,23 +155,9 @@ def parse(
raise InventoryIncorrectSchema(f"Inventory is not following the schema: {str(exc)}") from exc

# Read data from input
if inventory_input.hosts is not None:
for host in inventory_input.hosts:
device = AsyncEOSDevice(name=host.name, host=str(host.host), port=host.port, tags=host.tags, **kwargs)
inventory.add_device(device)
if inventory_input.networks is not None:
for network in inventory_input.networks:
for host_ip in IPNetwork(str(network.network)):
device = AsyncEOSDevice(host=str(host_ip), tags=network.tags, **kwargs)
inventory.add_device(device)
if inventory_input.ranges is not None:
for range_def in inventory_input.ranges:
range_increment = IPAddress(str(range_def.start))
range_stop = IPAddress(str(range_def.end))
while range_increment <= range_stop:
device = AsyncEOSDevice(host=str(range_increment), tags=range_def.tags, **kwargs)
inventory.add_device(device)
range_increment += 1
AntaInventory._parse_hosts(inventory_input, inventory, **kwargs)
AntaInventory._parse_networks(inventory_input, inventory, **kwargs)
AntaInventory._parse_ranges(inventory_input, inventory, **kwargs)

return inventory

Expand Down Expand Up @@ -167,7 +227,4 @@ async def connect_inventory(self) -> None:
for r in results:
if isinstance(r, Exception):
message = "Error when refreshing inventory"
if __DEBUG__:
logger.exception(message, exc_info=r)
else:
logger.error(f"{message}: {exc_to_str(r)}")
anta_log_exception(r, message, logger)
15 changes: 4 additions & 11 deletions anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from pydantic import BaseModel, ConfigDict, conint
from rich.progress import Progress, TaskID

from anta import __DEBUG__
from anta.result_manager.models import TestResult
from anta.tools.misc import exc_to_str
from anta.tools.misc import anta_log_exception, exc_to_str

if TYPE_CHECKING:
from anta.device import AntaDevice
Expand Down Expand Up @@ -231,10 +230,7 @@ async def collect(self) -> None:
await self.device.collect_commands(self.instance_commands)
except Exception as e: # pylint: disable=broad-exception-caught
message = f"Exception raised while collecting commands for test {self.name} (on device {self.device.name})"
if __DEBUG__:
self.logger.exception(message)
else:
self.logger.error(f"{message}: {exc_to_str(e)}")
anta_log_exception(e, message, self.logger)
self.result.is_error(exc_to_str(e))

@staticmethod
Expand All @@ -247,7 +243,7 @@ def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]:
async def wrapper(
self: AntaTest,
eos_data: list[dict[Any, Any] | str] | None = None,
**kwargs: dict[str, Any],
**kwargs: Any,
) -> TestResult:
"""
Wraps the test function and implement (in this order):
Expand Down Expand Up @@ -284,10 +280,7 @@ async def wrapper(
function(self, **kwargs)
except Exception as e: # pylint: disable=broad-exception-caught
message = f"Exception raised for test {self.name} (on device {self.device.name})"
if __DEBUG__:
self.logger.exception(message)
else:
self.logger.error(f"{message}: {exc_to_str(e)}")
anta_log_exception(e, message, self.logger)
self.result.is_error(exc_to_str(e))

AntaTest.update_progress()
Expand Down
4 changes: 2 additions & 2 deletions anta/result_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Iterator, List

from pydantic import BaseModel, RootModel, validator
from pydantic import BaseModel, RootModel, field_validator

RESULT_OPTIONS = ["unset", "success", "failure", "error", "skipped"]

Expand All @@ -28,7 +28,7 @@ class TestResult(BaseModel):
messages: List[str] = []

@classmethod
@validator("result", allow_reuse=True)
@field_validator("result")
def name_must_be_in(cls, v: str) -> str:
"""
Status validator
Expand Down
13 changes: 3 additions & 10 deletions anta/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple

from anta import __DEBUG__
from anta.inventory import AntaInventory
from anta.models import AntaTest
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult
from anta.tools.misc import exc_to_str
from anta.tools.misc import anta_log_exception

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,10 +63,7 @@ async def main(
coros.append(test_instance.test(eos_data=None, **test_params))
except Exception as e: # pylint: disable=broad-exception-caught
message = "Error when creating ANTA tests"
if __DEBUG__:
logger.exception(message)
else:
logger.error(f"{message}: {exc_to_str(e)}")
anta_log_exception(e, message, logger)

if AntaTest.progress is not None:
AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coros))
Expand All @@ -77,9 +73,6 @@ async def main(
for r in res:
if isinstance(r, Exception):
message = "Error in main ANTA Runner"
if __DEBUG__:
logger.exception(message, exc_info=r)
else:
logger.error(f"{message}: {exc_to_str(r)}")
anta_log_exception(r, message, logger)
res.remove(r)
manager.add_test_results(res)
Loading

0 comments on commit d4c4053

Please sign in to comment.