diff --git a/anta/catalog.py b/anta/catalog.py index 30bd34066..7ed4bc718 100644 --- a/anta/catalog.py +++ b/anta/catalog.py @@ -10,9 +10,11 @@ import math from collections import defaultdict from inspect import isclass +from itertools import chain from json import load as json_load from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from warnings import warn from pydantic import BaseModel, ConfigDict, RootModel, ValidationError, ValidationInfo, field_validator, model_serializer, model_validator from pydantic.types import ImportString @@ -386,6 +388,21 @@ def from_list(data: ListAntaTestTuples) -> AntaCatalog: raise return AntaCatalog(tests) + @classmethod + def merge_catalogs(cls, catalogs: list[AntaCatalog]) -> AntaCatalog: + """Merge multiple AntaCatalog instances. + + Parameters + ---------- + catalogs: A list of AntaCatalog instances to merge. + + Returns + ------- + A new AntaCatalog instance containing the tests of all the input catalogs. + """ + combined_tests = list(chain(*(catalog.tests for catalog in catalogs))) + return cls(tests=combined_tests) + def merge(self, catalog: AntaCatalog) -> AntaCatalog: """Merge two AntaCatalog instances. @@ -397,7 +414,13 @@ def merge(self, catalog: AntaCatalog) -> AntaCatalog: ------- A new AntaCatalog instance containing the tests of the two instances. """ - return AntaCatalog(tests=self.tests + catalog.tests) + # TODO: Use a decorator to deprecate this method instead. See https://github.com/aristanetworks/anta/issues/754 + warn( + message="AntaCatalog.merge() is deprecated and will be removed in ANTA v2.0. Use AntaCatalog.merge_catalogs() instead.", + category=DeprecationWarning, + stacklevel=2, + ) + return self.merge_catalogs([self, catalog]) def dump(self) -> AntaCatalogFile: """Return an AntaCatalogFile instance from this AntaCatalog instance. diff --git a/docs/usage-inventory-catalog.md b/docs/usage-inventory-catalog.md index fd6aec320..5ae4cc923 100644 --- a/docs/usage-inventory-catalog.md +++ b/docs/usage-inventory-catalog.md @@ -309,7 +309,7 @@ Once you run `anta nrfu table`, you will see following output: ### Example script to merge catalogs -The following script reads all the files in `intended/test_catalogs/` with names `-catalog.yml` and merge them together inside one big catalog `anta-catalog.yml`. +The following script reads all the files in `intended/test_catalogs/` with names `-catalog.yml` and merge them together inside one big catalog `anta-catalog.yml` using the new `AntaCatalog.merge_catalogs()` class method. ```python #!/usr/bin/env python @@ -319,19 +319,26 @@ from pathlib import Path from anta.models import AntaTest -CATALOG_SUFFIX = '-catalog.yml' -CATALOG_DIR = 'intended/test_catalogs/' +CATALOG_SUFFIX = "-catalog.yml" +CATALOG_DIR = "intended/test_catalogs/" if __name__ == "__main__": - catalog = AntaCatalog() - for file in Path(CATALOG_DIR).glob('*'+CATALOG_SUFFIX): - c = AntaCatalog.parse(file) + catalogs = [] + for file in Path(CATALOG_DIR).glob("*" + CATALOG_SUFFIX): device = str(file).removesuffix(CATALOG_SUFFIX).removeprefix(CATALOG_DIR) - print(f"Merging test catalog for device {device}") - # Apply filters to all tests for this device - for test in c.tests: - test.inputs.filters = AntaTest.Input.Filters(tags=[device]) - catalog = catalog.merge(c) + print(f"Loading test catalog for device {device}") + catalog = AntaCatalog.parse(file) + # Add the device name as a tag to all tests in the catalog + for test in catalog.tests: + test.inputs.filters = AntaTest.Input.Filters(tags={device}) + catalogs.append(catalog) + + # Merge all catalogs + merged_catalog = AntaCatalog.merge_catalogs(catalogs) + + # Save the merged catalog to a file with open(Path('anta-catalog.yml'), "w") as f: f.write(catalog.dump().yaml()) ``` +!!! warning + The `AntaCatalog.merge()` method is deprecated and will be removed in ANTA v2.0. Please use the `AntaCatalog.merge_catalogs()` class method instead. diff --git a/tests/units/test_catalog.py b/tests/units/test_catalog.py index 76358dd4a..13046f294 100644 --- a/tests/units/test_catalog.py +++ b/tests/units/test_catalog.py @@ -345,6 +345,17 @@ def test_get_tests_by_tags(self) -> None: tests = catalog.get_tests_by_tags(tags={"leaf", "spine"}, strict=True) assert len(tests) == 1 + def test_merge_catalogs(self) -> None: + """Test the merge_catalogs function.""" + # Load catalogs of different sizes + small_catalog = AntaCatalog.parse(DATA_DIR / "test_catalog.yml") + medium_catalog = AntaCatalog.parse(DATA_DIR / "test_catalog_medium.yml") + tagged_catalog = AntaCatalog.parse(DATA_DIR / "test_catalog_with_tags.yml") + + # Merge the catalogs and check the number of tests + final_catalog = AntaCatalog.merge_catalogs([small_catalog, medium_catalog, tagged_catalog]) + assert len(final_catalog.tests) == len(small_catalog.tests) + len(medium_catalog.tests) + len(tagged_catalog.tests) + def test_merge(self) -> None: """Test AntaCatalog.merge().""" catalog1: AntaCatalog = AntaCatalog.parse(DATA_DIR / "test_catalog.yml") @@ -354,11 +365,15 @@ def test_merge(self) -> None: catalog3: AntaCatalog = AntaCatalog.parse(DATA_DIR / "test_catalog_medium.yml") assert len(catalog3.tests) == 228 - assert len(catalog1.merge(catalog2).tests) == 2 + with pytest.deprecated_call(): + merged_catalog = catalog1.merge(catalog2) + assert len(merged_catalog.tests) == 2 assert len(catalog1.tests) == 1 assert len(catalog2.tests) == 1 - assert len(catalog2.merge(catalog3).tests) == 229 + with pytest.deprecated_call(): + merged_catalog = catalog2.merge(catalog3) + assert len(merged_catalog.tests) == 229 assert len(catalog2.tests) == 1 assert len(catalog3.tests) == 228