Update unit tests for ResultManager
carl-baillargeon committed Jul 11, 2024
1 parent 30c8d7e commit 73208a5
Showing 3 changed files with 68 additions and 72 deletions.
32 changes: 31 additions & 1 deletion tests/lib/fixture.py
@@ -5,8 +5,10 @@

from __future__ import annotations

import json
import logging
import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from unittest.mock import patch

@@ -23,12 +25,15 @@

if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path

from anta.models import AntaCommand

logger = logging.getLogger(__name__)

DATA_DIR: Path = Path(__file__).parent.parent.resolve() / "data"

JSON_RESULTS = "test_md_report_results.json"

DEVICE_HW_MODEL = "pytest"
DEVICE_NAME = "pytest"
COMMAND_OUTPUT = "retrieved"
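Note: since fixture.py lives under tests/lib, Path(__file__).parent.parent resolves to the tests directory, so DATA_DIR points at tests/data, which is where test_md_report_results.json is expected to be found.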
@@ -154,6 +159,31 @@ def _factory(number: int = 0) -> ResultManager:
return _factory


@pytest.fixture()
def result_manager() -> ResultManager:
"""Return a ResultManager with 89 random tests loaded from a JSON file.
Devices: DC1-SPINE1, DC1-LEAF1A
- Total tests: 89
- Success: 31
- Skipped: 8
- Failure: 48
- Error: 2
See `tests/data/test_md_report_results.json` and `tests/data/test_md_report_all_tests.md` for details.
"""
manager = ResultManager()

with (DATA_DIR / JSON_RESULTS).open("r", encoding="utf-8") as f:
results = json.load(f)

for result in results:
manager.add(TestResult(**result))

return manager

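The fixture unpacks each JSON entry straight into TestResult(**result). A minimal sketch of what one entry might look like, assuming the keys mirror the TestResult model fields; the values are illustrative and not taken from test_md_report_results.json:

# Hypothetical shape of one loaded entry (assumption, for illustration only)
result = {
    "name": "DC1-SPINE1",        # device name
    "test": "VerifyNTP",         # test class name
    "categories": ["system"],
    "description": "Verifies if NTP is synchronised.",
    "result": "success",         # success, failure, error or skipped
    "messages": [],
    "custom_field": None,
}
manager.add(TestResult(**result))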

# tests.units.cli fixtures
@pytest.fixture()
def temp_env(tmp_path: Path) -> dict[str, str | None]:
15 changes: 2 additions & 13 deletions tests/units/reporter/test_md_reporter.py
@@ -5,15 +5,13 @@

from __future__ import annotations

import json
from io import StringIO
from pathlib import Path

import pytest

from anta.reporter.md_reporter import MDReportBase, MDReportGenerator
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult as FakeTestResult

DATA_DIR: Path = Path(__file__).parent.parent.parent.resolve() / "data"

@@ -25,22 +23,13 @@
pytest.param(False, "test_md_report_all_tests.md", id="all_tests"),
],
)
def test_md_report_generate(tmp_path: Path, expected_report_name: str, *, only_failed_tests: bool) -> None:
def test_md_report_generate(tmp_path: Path, result_manager: ResultManager, expected_report_name: str, *, only_failed_tests: bool) -> None:
"""Test the MDReportGenerator class."""
# Create a temporary Markdown file
md_filename = tmp_path / "test.md"

manager = ResultManager()

# Load JSON results into the manager
with (DATA_DIR / "test_md_report_results.json").open("r", encoding="utf-8") as f:
results = json.load(f)

for result in results:
manager.add(FakeTestResult(**result))

# Generate the Markdown report
MDReportGenerator.generate(manager, md_filename, only_failed_tests=only_failed_tests)
MDReportGenerator.generate(result_manager, md_filename, only_failed_tests=only_failed_tests)
assert md_filename.exists()

# Load the existing Markdown report to compare with the generated one
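The remainder of the test (outside this hunk) compares the generated file with the expected Markdown report named in the parametrize marker. A rough sketch of such a comparison, assuming the expected reports sit alongside the JSON data in DATA_DIR; this is illustrative, not the elided code:

# Sketch only -- assumes expected reports are stored in DATA_DIR
expected_report = (DATA_DIR / expected_report_name).read_text(encoding="utf-8")
assert md_filename.read_text(encoding="utf-8") == expected_report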
93 changes: 35 additions & 58 deletions tests/units/result_manager/test__init__.py
@@ -177,54 +177,40 @@ def test_add_type_error(self) -> None:
with pytest.raises(TypeError, match="Added test result 'test' must be a TestResult instance, got str."):
result_manager.add("test") # type: ignore[arg-type]

def test_get_results(self, test_result_factory: Callable[[], TestResult]) -> None:
def test_get_results(self, result_manager: ResultManager) -> None:
"""Test ResultManager.get_results."""
result_manager = ResultManager()
assert result_manager.get_results() == []

success_result = test_result_factory()
success_result.result = "success"
success_result.categories = ["ospf"]
result_manager.add(success_result)

failure_result = test_result_factory()
failure_result.result = "failure"
failure_result.categories = ["bgp"]
result_manager.add(failure_result)

skipped_result = test_result_factory()
skipped_result.result = "skipped"
result_manager.add(skipped_result)

error_result = test_result_factory()
error_result.result = "error"
result_manager.add(error_result)

# Check for single status
success_results = result_manager.get_results(status="success")
assert len(success_results) == 1
assert success_results[0].result == "success"
assert len(success_results) == 31
assert all(r.result == "success" for r in success_results)

# Check for multiple statuses
success_failure_results = result_manager.get_results(status={"success", "failure"})
assert len(success_failure_results) == 2
assert all(r.result in {"success", "failure"} for r in success_failure_results)
failure_results = result_manager.get_results(status={"failure", "error"})
assert len(failure_results) == 50
assert all(r.result in {"failure", "error"} for r in failure_results)

# Check all results
all_results = result_manager.get_results()
assert len(all_results) == 4
assert [r.result for r in all_results] == ["success", "failure", "skipped", "error"]
assert len(all_results) == 89

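The expected counts come from the result_manager fixture (31 success, 48 failure, 8 skipped, 2 error). As a rough cross-check, assuming ResultManager exposes its entries via the results property, the same numbers could be derived directly:

from collections import Counter

# Assumption: result_manager.results returns all TestResult entries
status_counts = Counter(r.result for r in result_manager.results)
assert status_counts["success"] == 31
assert status_counts["failure"] + status_counts["error"] == 50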
def test_get_results_sort_by(self, result_manager: ResultManager) -> None:
"""Test ResultManager.get_results with sort_by."""
# Check all results with sort_by result
all_results = result_manager.get_results(sort_by=["result"])
assert len(all_results) == 4
assert [r.result for r in all_results] == ["error", "failure", "skipped", "success"]
assert len(all_results) == 89
assert [r.result for r in all_results] == ["error"] * 2 + ["failure"] * 48 + ["skipped"] * 8 + ["success"] * 31

# Check all results with sort_by device (name)
all_results = result_manager.get_results(sort_by=["name"])
assert len(all_results) == 89
assert all_results[0].name == "DC1-LEAF1A"
assert all_results[-1].name == "DC1-SPINE1"

# Check multiple statuses with sort_by categories
success_failure_results = result_manager.get_results(status={"success", "failure"}, sort_by=["categories"])
assert len(success_failure_results) == 2
assert success_failure_results[0] == failure_result
assert success_failure_results[1] == success_result
success_skipped_results = result_manager.get_results(status={"success", "skipped"}, sort_by=["categories"])
assert len(success_skipped_results) == 39
assert success_skipped_results[0].categories == ["BFD"]
assert success_skipped_results[-1].categories == ["VXLAN"]

# Check all results with bad sort_by
with pytest.raises(
@@ -235,30 +221,21 @@ def test_get_results(self, test_result_factory: Callable[[], TestResult]) -> None:
):
all_results = result_manager.get_results(sort_by=["bad_field"])

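The sort_by assertions order entries on TestResult attributes. A rough equivalent, assuming get_results(sort_by=...) behaves like attrgetter-based sorting (an assumption about the implementation, not taken from it):

from operator import attrgetter

# Assumed to produce the same ordering as get_results(sort_by=["result"])
manually_sorted = sorted(result_manager.get_results(), key=attrgetter("result"))
assert [r.result for r in manually_sorted] == ["error"] * 2 + ["failure"] * 48 + ["skipped"] * 8 + ["success"] * 31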
def test_get_total_results(self, test_result_factory: Callable[[], TestResult]) -> None:
def test_get_total_results(self, result_manager: ResultManager) -> None:
"""Test ResultManager.get_total_results."""
result_manager = ResultManager()
assert result_manager.get_total_results() == 0

success_result = test_result_factory()
success_result.result = "success"
result_manager.add(success_result)

failure_result = test_result_factory()
failure_result.result = "failure"
result_manager.add(failure_result)

skipped_result = test_result_factory()
skipped_result.result = "skipped"
result_manager.add(skipped_result)

error_result = test_result_factory()
error_result.result = "error"
result_manager.add(error_result)

assert result_manager.get_total_results(status="success") == 1
assert result_manager.get_total_results(status={"success", "failure"}) == 2
assert result_manager.get_total_results() == 4
# Test all results
assert result_manager.get_total_results() == 89

# Test single status
assert result_manager.get_total_results(status="success") == 31
assert result_manager.get_total_results(status="failure") == 48
assert result_manager.get_total_results(status="error") == 2
assert result_manager.get_total_results(status="skipped") == 8

# Test multiple statuses
assert result_manager.get_total_results(status={"success", "failure"}) == 79
assert result_manager.get_total_results(status={"success", "failure", "error"}) == 81
assert result_manager.get_total_results(status={"success", "failure", "error", "skipped"}) == 89

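The multi-status totals follow from the single-status counts: 31 + 48 = 79, adding the 2 errors gives 81, and adding the 8 skipped results gives all 89.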
@pytest.mark.parametrize(
("status", "error_status", "ignore_error", "expected_status"),