diff --git a/tests/conftest.py b/tests/conftest.py index 8383feaf..5f391cb2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. -"""Test fixtures""" +"""Common test fixtures.""" import argparse +import logging import os import pathlib from tempfile import TemporaryDirectory @@ -27,6 +28,7 @@ from trestle.common.err import TrestleError from trestle.core.commands.init import InitCmd +from tests.testutils import clean from trestlebot import const from trestlebot.transformers.trestle_rule import ( ComponentInfo, @@ -38,21 +40,21 @@ T = TypeVar("T") - YieldFixture = Generator[T, None, None] _TEST_CONTENTS = """ test file """ - _TEST_FILENAME = "test.txt" +_TEST_PREFIX = "trestlebot_tests" -@pytest.fixture +@pytest.fixture(scope="function") def tmp_repo() -> YieldFixture[Tuple[str, Repo]]: """Create a temporary git repository""" - with TemporaryDirectory(prefix="trestlebot_tests") as tmpdir: - with open(os.path.join(tmpdir, _TEST_FILENAME), "x", encoding="utf8") as file: + with TemporaryDirectory(prefix=_TEST_PREFIX) as tmpdir: + test_file = os.path.join(tmpdir, _TEST_FILENAME) + with open(test_file, "x", encoding="utf8") as file: file.write(_TEST_CONTENTS) repo = Repo.init(tmpdir) with repo.config_writer() as config: @@ -62,11 +64,16 @@ def tmp_repo() -> YieldFixture[Tuple[str, Repo]]: repo.index.commit("Initial commit") yield tmpdir, repo + try: + clean(tmpdir, repo) + except Exception as e: + logging.error(f"Failed to clean up temporary git repository: {e}") + -@pytest.fixture +@pytest.fixture(scope="function") def tmp_trestle_dir() -> YieldFixture[str]: """Create an initialized temporary trestle directory""" - with TemporaryDirectory(prefix="trestlebot_tests") as tmpdir: + with TemporaryDirectory(prefix=_TEST_PREFIX) as tmpdir: tmp_path = pathlib.Path(tmpdir) try: args = argparse.Namespace( diff --git a/tests/trestlebot/test_bot.py b/tests/trestlebot/test_bot.py index eeea6c36..c2876648 100644 --- a/tests/trestlebot/test_bot.py +++ b/tests/trestlebot/test_bot.py @@ -25,8 +25,8 @@ from git.repo import Repo import trestlebot.bot as bot -from tests.testutils import clean from trestlebot.provider import GitProvider, GitProviderException +from trestlebot.tasks.base_task import TaskBase, TaskException def check_lists_equal(list1: List[str], list2: List[str]) -> bool: @@ -67,8 +67,6 @@ def test_stage_files( assert check_lists_equal(staged_files, expected_files) is True - clean(repo_path, repo) - def test_local_commit(tmp_repo: Tuple[str, Repo]) -> None: """Test local commit function""" @@ -99,8 +97,6 @@ def test_local_commit(tmp_repo: Tuple[str, Repo]) -> None: # Verify that the file is tracked by the commit assert os.path.basename(test_file_path) in commit.stats.files - clean(repo_path, repo) - def test_local_commit_with_committer(tmp_repo: Tuple[str, Repo]) -> None: """Test setting committer information for commits""" @@ -132,8 +128,6 @@ def test_local_commit_with_committer(tmp_repo: Tuple[str, Repo]) -> None: # Verify that the file is tracked by the commit assert os.path.basename(test_file_path) in commit.stats.files - clean(repo_path, repo) - def test_local_commit_with_author(tmp_repo: Tuple[str, Repo]) -> None: """Test setting author for commits""" @@ -166,8 +160,6 @@ def test_local_commit_with_author(tmp_repo: Tuple[str, Repo]) -> None: # Verify that the file is tracked by the commit assert os.path.basename(test_file_path) in commit.stats.files - clean(repo_path, repo) - def test_run(tmp_repo: Tuple[str, Repo]) -> None: """Test bot run with mocked push""" @@ -208,8 +200,6 @@ def test_run(tmp_repo: Tuple[str, Repo]) -> None: # Verify that the file is tracked by the commit assert os.path.basename(test_file_path) in commit.stats.files - clean(repo_path, repo) - def test_run_dry_run(tmp_repo: Tuple[str, Repo]) -> None: """Test bot run with dry run""" @@ -240,8 +230,6 @@ def test_run_dry_run(tmp_repo: Tuple[str, Repo]) -> None: mock_push.assert_not_called() - clean(repo_path, repo) - def test_empty_commit(tmp_repo: Tuple[str, Repo]) -> None: """Test running bot with no file updates""" @@ -262,8 +250,6 @@ def test_empty_commit(tmp_repo: Tuple[str, Repo]) -> None: assert commit_sha == "" assert pr_number == 0 - clean(repo_path, repo) - def test_run_check_only(tmp_repo: Tuple[str, Repo]) -> None: """Test bot run with check_only""" @@ -291,8 +277,6 @@ def test_run_check_only(tmp_repo: Tuple[str, Repo]) -> None: check_only=True, ) - clean(repo_path, repo) - def push_side_effect(refspec: str) -> None: raise GitCommandError("example") @@ -338,7 +322,34 @@ def test_run_with_exception( dry_run=False, ) - clean(repo_path, repo) + +def test_run_with_failed_pre_task(tmp_repo: Tuple[str, Repo]) -> None: + """Test bot run with mocked task that fails""" + repo_path, repo = tmp_repo + + # Create a test file + test_file_path = os.path.join(repo_path, "test.txt") + with open(test_file_path, "w") as f: + f.write("Test content") + + mock = Mock(spec=TaskBase) + mock.execute.side_effect = TaskException("example") + + repo.create_remote("origin", url="git.test.com/test/repo.git") + + with pytest.raises(bot.RepoException, match="Bot pre-tasks failed: example"): + _ = bot.run( + working_dir=repo_path, + branch="main", + commit_name="Test User", + commit_email="test@example.com", + commit_message="Test commit message", + author_name="The Author", + author_email="author@test.com", + patterns=["*.txt"], + dry_run=True, + pre_tasks=[mock], + ) def test_run_with_provider(tmp_repo: Tuple[str, Repo]) -> None: @@ -396,8 +407,6 @@ def test_run_with_provider(tmp_repo: Tuple[str, Repo]) -> None: ) mock_push.assert_called_once_with(refspec="HEAD:test") - clean(repo_path, repo) - def test_run_with_provider_with_custom_pr_title(tmp_repo: Tuple[str, Repo]) -> None: """Test bot run with customer pull request title""" @@ -453,5 +462,3 @@ def test_run_with_provider_with_custom_pr_title(tmp_repo: Tuple[str, Repo]) -> N body="", ) mock_push.assert_called_once_with(refspec="HEAD:test") - - clean(repo_path, repo) diff --git a/tests/trestlebot/test_cli.py b/tests/trestlebot/test_cli.py index f883891b..5440dc88 100644 --- a/tests/trestlebot/test_cli.py +++ b/tests/trestlebot/test_cli.py @@ -17,7 +17,6 @@ """Test for CLI""" import logging -import sys from typing import Any, Dict, List from unittest.mock import patch @@ -48,36 +47,24 @@ def args_dict_to_list(args_dict: Dict[str, str]) -> List[str]: return args -def test_invalid_oscal_model( - monkeypatch: Any, valid_args_dict: Dict[str, str], caplog: Any -) -> None: +def test_invalid_oscal_model(valid_args_dict: Dict[str, str], caplog: Any) -> None: """Test invalid oscal model""" args_dict = valid_args_dict args_dict["oscal-model"] = "fake" - monkeypatch.setattr(sys, "argv", ["trestlebot", *args_dict_to_list(args_dict)]) - with pytest.raises(SystemExit): - cli_main() - - assert any( - record.levelno == logging.ERROR - and record.message - == "Invalid value fake for oscal model. Please use catalog, profile, compdef, or ssp." - for record in caplog.records - ) + with patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]): + with pytest.raises(SystemExit, match="2"): + cli_main() -def test_no_ssp_index( - monkeypatch: Any, valid_args_dict: Dict[str, str], caplog: Any -) -> None: +def test_no_ssp_index(valid_args_dict: Dict[str, str], caplog: Any) -> None: """Test missing index file for ssp""" args_dict = valid_args_dict args_dict["oscal-model"] = "ssp" args_dict["ssp-index-path"] = "" - monkeypatch.setattr(sys, "argv", ["trestlebot", *args_dict_to_list(args_dict)]) - - with pytest.raises(SystemExit): - cli_main() + with patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]): + with pytest.raises(SystemExit): + cli_main() assert any( record.levelno == logging.ERROR @@ -86,16 +73,13 @@ def test_no_ssp_index( ) -def test_no_markdown_path( - monkeypatch: Any, valid_args_dict: Dict[str, str], caplog: Any -) -> None: +def test_no_markdown_path(valid_args_dict: Dict[str, str], caplog: Any) -> None: """Test without a markdown file passed as a flag""" args_dict = valid_args_dict args_dict["markdown-path"] = "" - monkeypatch.setattr(sys, "argv", ["trestlebot", *args_dict_to_list(args_dict)]) - - with pytest.raises(SystemExit): - cli_main() + with patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]): + with pytest.raises(SystemExit): + cli_main() assert any( record.levelno == logging.ERROR @@ -104,17 +88,16 @@ def test_no_markdown_path( ) -def test_with_target_branch( - monkeypatch: Any, valid_args_dict: Dict[str, str], caplog: Any -) -> None: +def test_with_target_branch(valid_args_dict: Dict[str, str], caplog: Any) -> None: """Test with target branch set an an unsupported Git provider""" args_dict = valid_args_dict args_dict["target-branch"] = "main" - monkeypatch.setattr(sys, "argv", ["trestlebot", *args_dict_to_list(args_dict)]) # Patch is_github_actions since these tests will be running in # GitHub Actions - with patch("trestlebot.cli_base.is_github_actions") as mock_check: + with patch("trestlebot.cli_base.is_github_actions") as mock_check, patch( + "sys.argv", ["trestlebot", *args_dict_to_list(args_dict)] + ): mock_check.return_value = False with pytest.raises(SystemExit): diff --git a/tests/trestlebot/transformers/test_csv_transformer.py b/tests/trestlebot/transformers/test_csv_transformer.py index 6dc61bb4..61f617eb 100644 --- a/tests/trestlebot/transformers/test_csv_transformer.py +++ b/tests/trestlebot/transformers/test_csv_transformer.py @@ -22,7 +22,11 @@ import pytest -from trestlebot.transformers.csv_transformer import CSVBuilder +from trestlebot.transformers.csv_transformer import ( + CSVBuilder, + FromRulesCSVTransformer, + ToRulesCSVTransformer, +) from trestlebot.transformers.trestle_rule import TrestleRule @@ -68,3 +72,14 @@ def test_validate_row() -> None: csv_builder = CSVBuilder() with pytest.raises(RuntimeError, match="Row missing key: *"): csv_builder.validate_row(row) + + +def test_read_write_integration(test_rule: TrestleRule) -> None: + """Test read/write integration.""" + from_rules_transformer = FromRulesCSVTransformer() + to_rules_transformer = ToRulesCSVTransformer() + + csv_row_data = from_rules_transformer.transform(test_rule) + read_rule = to_rules_transformer.transform(csv_row_data) + + assert read_rule == test_rule diff --git a/trestlebot/cli.py b/trestlebot/cli.py index e5d8b194..b5f0ffef 100644 --- a/trestlebot/cli.py +++ b/trestlebot/cli.py @@ -56,7 +56,8 @@ def setup_autosync_arguments(self) -> None: "--oscal-model", required=True, type=str, - help="OSCAL model type to run tasks on. Values can be catalog, profile, compdef, or ssp", + choices=["catalog", "profile", "compdef", "ssp"], + help="OSCAL model type to run tasks on.", ) self.parser.add_argument( "--skip-items",