From 4ca69589fc201cd17570511f9a83d40038cbf3d7 Mon Sep 17 00:00:00 2001 From: Jennifer Power Date: Wed, 13 Dec 2023 11:46:42 -0500 Subject: [PATCH] refactor: streamlines exit exception handling (#103) * feat: adds guard during AuthoredObject creation to verify trestle root exists Signed-off-by: Jennifer Power * refactor: streamlines error handling for input validation Adds a custom exception and adds a try/except block to each entrypoint run function to centralize exit error handling to `handle_exception`. Signed-off-by: Jennifer Power --------- Signed-off-by: Jennifer Power --- tests/trestlebot/entrypoints/test_autosync.py | 35 ++++-- trestlebot/entrypoints/autosync.py | 103 ++++++++++-------- trestlebot/entrypoints/create_cd.py | 100 +++++++++-------- trestlebot/entrypoints/entrypoint_base.py | 95 ++++++++-------- trestlebot/entrypoints/rule_transform.py | 61 +++++++---- trestlebot/tasks/authored/base_authored.py | 3 + 6 files changed, 232 insertions(+), 165 deletions(-) diff --git a/tests/trestlebot/entrypoints/test_autosync.py b/tests/trestlebot/entrypoints/test_autosync.py index 1a71957a..1120da2a 100644 --- a/tests/trestlebot/entrypoints/test_autosync.py +++ b/tests/trestlebot/entrypoints/test_autosync.py @@ -14,7 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -"""Test for CLI""" +"""Test for Autosync CLI""" import logging from typing import Any, Dict @@ -34,7 +34,7 @@ def valid_args_dict() -> Dict[str, str]: "oscal-model": "profile", "committer-name": "test", "committer-email": "test@email.com", - "working-dir": "tmp", + "working-dir": ".", "file-patterns": ".", } @@ -54,12 +54,13 @@ def test_no_ssp_index(valid_args_dict: Dict[str, str], caplog: Any) -> None: args_dict["oscal-model"] = "ssp" args_dict["ssp-index-path"] = "" with patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]): - with pytest.raises(SystemExit): + with pytest.raises(SystemExit, match="2"): cli_main() assert any( record.levelno == logging.ERROR - and record.message == "Must set ssp_index_path when using SSP as oscal model." + and "Invalid args --ssp-index-path: Must set ssp index path when using SSP as " + "oscal model." in record.message for record in caplog.records ) @@ -69,12 +70,27 @@ def test_no_markdown_path(valid_args_dict: Dict[str, str], caplog: Any) -> None: args_dict = valid_args_dict args_dict["markdown-path"] = "" with patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]): - with pytest.raises(SystemExit): + with pytest.raises(SystemExit, match="2"): cli_main() assert any( record.levelno == logging.ERROR - and record.message == "Must set markdown path with oscal model." + and "Invalid args --markdown-path: Markdown path must be set." in record.message + for record in caplog.records + ) + + +def test_non_existent_working_dir(valid_args_dict: Dict[str, str], caplog: Any) -> None: + """Test with a non-existent working directory""" + args_dict = valid_args_dict + args_dict["working-dir"] = "tmp" + with patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]): + with pytest.raises(SystemExit, match="1"): + cli_main() + + assert any( + record.levelno == logging.ERROR + and "Root path tmp does not exist" in record.message for record in caplog.records ) @@ -91,13 +107,14 @@ def test_with_target_branch(valid_args_dict: Dict[str, str], caplog: Any) -> Non ) as mock_check, patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]): mock_check.return_value = False - with pytest.raises(SystemExit): + with pytest.raises(SystemExit, match="2"): cli_main() assert any( record.levelno == logging.ERROR - and record.message == "target-branch flag is set with an unset git provider. " - "To test locally, set the GITHUB_ACTIONS or GITLAB_CI environment variable." + and "Invalid args --target-branch: target-branch flag is set with an " + "unset git provider. To test locally, set the GITHUB_ACTIONS or GITLAB_CI environment variable." + in record.message for record in caplog.records ) diff --git a/trestlebot/entrypoints/autosync.py b/trestlebot/entrypoints/autosync.py index 51221933..4e2710f8 100644 --- a/trestlebot/entrypoints/autosync.py +++ b/trestlebot/entrypoints/autosync.py @@ -27,8 +27,13 @@ import sys from typing import List -from trestlebot import const -from trestlebot.entrypoints.entrypoint_base import EntrypointBase, comma_sep_to_list +from trestlebot.const import SUCCESS_EXIT_CODE +from trestlebot.entrypoints.entrypoint_base import ( + EntrypointBase, + EntrypointInvalidArgException, + comma_sep_to_list, + handle_exception, +) from trestlebot.entrypoints.log import set_log_level_from_args from trestlebot.tasks.assemble_task import AssembleTask from trestlebot.tasks.authored import types @@ -95,63 +100,71 @@ def validate_args(self, args: argparse.Namespace) -> None: """Validate the arguments for the autosync entrypoint.""" authored_list: List[str] = [model.value for model in types.AuthoredType] if args.oscal_model not in authored_list: - logger.error( - f"Invalid value {args.oscal_model} for oscal model. " - f"Please use one of {authored_list}" + raise EntrypointInvalidArgException( + "--oscal-model", + f"Value {args.oscal_model} is not valid." + f"Please use one of {authored_list}", ) - sys.exit(const.ERROR_EXIT_CODE) if not args.markdown_path: - logger.error("Must set markdown path with oscal model.") - sys.exit(const.ERROR_EXIT_CODE) + raise EntrypointInvalidArgException( + "--markdown-path", "Markdown path must be set." + ) if ( args.oscal_model == types.AuthoredType.SSP.value and args.ssp_index_path == "" ): - logger.error("Must set ssp_index_path when using SSP as oscal model.") - sys.exit(const.ERROR_EXIT_CODE) + raise EntrypointInvalidArgException( + "--ssp-index-path", + "Must set ssp index path when using SSP as oscal model.", + ) def run(self, args: argparse.Namespace) -> None: """Run the autosync entrypoint.""" - - set_log_level_from_args(args) - self.validate_args(args) - - pre_tasks: List[TaskBase] = [] - # Allow any model to be skipped from the args, by default include all - model_filter: ModelFilter = ModelFilter( - skip_patterns=comma_sep_to_list(args.skip_items), - include_patterns=["*"], - ) - - authored_object: AuthoredObjectBase = types.get_authored_object( - args.oscal_model, args.working_dir, args.ssp_index_path - ) - - # Assuming an edit has occurred assemble would be run before regenerate. - # Adding this to the list first - if not args.skip_assemble: - assemble_task = AssembleTask( - authored_object=authored_object, - markdown_dir=args.markdown_path, - model_filter=model_filter, + exit_code: int = SUCCESS_EXIT_CODE + try: + set_log_level_from_args(args) + self.validate_args(args) + + pre_tasks: List[TaskBase] = [] + # Allow any model to be skipped from the args, by default include all + model_filter: ModelFilter = ModelFilter( + skip_patterns=comma_sep_to_list(args.skip_items), + include_patterns=["*"], ) - pre_tasks.append(assemble_task) - else: - logger.info("Assemble task skipped.") - - if not args.skip_regenerate: - regenerate_task = RegenerateTask( - authored_object=authored_object, - markdown_dir=args.markdown_path, - model_filter=model_filter, + + authored_object: AuthoredObjectBase = types.get_authored_object( + args.oscal_model, args.working_dir, args.ssp_index_path ) - pre_tasks.append(regenerate_task) - else: - logger.info("Regeneration task skipped.") - super().run_base(args, pre_tasks) + # Assuming an edit has occurred assemble would be run before regenerate. + # Adding this to the list first + if not args.skip_assemble: + assemble_task: AssembleTask = AssembleTask( + authored_object=authored_object, + markdown_dir=args.markdown_path, + model_filter=model_filter, + ) + pre_tasks.append(assemble_task) + else: + logger.info("Assemble task skipped.") + + if not args.skip_regenerate: + regenerate_task: RegenerateTask = RegenerateTask( + authored_object=authored_object, + markdown_dir=args.markdown_path, + model_filter=model_filter, + ) + pre_tasks.append(regenerate_task) + else: + logger.info("Regeneration task skipped.") + + super().run_base(args, pre_tasks) + except Exception as e: + exit_code = handle_exception(e) + + sys.exit(exit_code) def main() -> None: diff --git a/trestlebot/entrypoints/create_cd.py b/trestlebot/entrypoints/create_cd.py index ca3ff96b..e1eed527 100644 --- a/trestlebot/entrypoints/create_cd.py +++ b/trestlebot/entrypoints/create_cd.py @@ -24,10 +24,11 @@ import argparse import logging import pathlib +import sys from typing import List, Optional -from trestlebot.const import RULE_PREFIX, RULES_VIEW_DIR -from trestlebot.entrypoints.entrypoint_base import EntrypointBase +from trestlebot.const import RULE_PREFIX, RULES_VIEW_DIR, SUCCESS_EXIT_CODE +from trestlebot.entrypoints.entrypoint_base import EntrypointBase, handle_exception from trestlebot.entrypoints.log import set_log_level_from_args from trestlebot.tasks.authored.compdef import ( AuthoredComponentDefinition, @@ -91,49 +92,58 @@ def setup_create_cd_arguments(self) -> None: def run(self, args: argparse.Namespace) -> None: """Run the entrypoint.""" - - set_log_level_from_args(args) - pre_tasks: List[TaskBase] = [] - filter_by_profile: Optional[FilterByProfile] = None - trestle_root: pathlib.Path = pathlib.Path(args.working_dir) - - if args.filter_by_profile: - filter_by_profile = FilterByProfile(trestle_root, args.filter_by_profile) - - authored_comp = AuthoredComponentDefinition(args.working_dir) - authored_comp.create_new_default( - args.profile_name, - args.compdef_name, - args.component_title, - args.component_description, - args.component_definition_type, - filter_by_profile, - ) - - transformer: ToRulesYAMLTransformer = ToRulesYAMLTransformer() - - # In this case we only want to do the transformation and generation for this component - # definition, so we skip all other component definitions and components. - model_filter: ModelFilter = ModelFilter( - [], [args.compdef_name, args.component_title, f"{RULE_PREFIX}*"] - ) - - rule_transform_task: RuleTransformTask = RuleTransformTask( - working_dir=args.working_dir, - rules_view_dir=RULES_VIEW_DIR, - rule_transformer=transformer, - model_filter=model_filter, - ) - pre_tasks.append(rule_transform_task) - - regenerate_task = RegenerateTask( - authored_object=authored_comp, - markdown_dir=args.markdown_path, - model_filter=model_filter, - ) - pre_tasks.append(regenerate_task) - - super().run_base(args, pre_tasks) + exit_code: int = SUCCESS_EXIT_CODE + try: + set_log_level_from_args(args) + pre_tasks: List[TaskBase] = [] + filter_by_profile: Optional[FilterByProfile] = None + trestle_root: pathlib.Path = pathlib.Path(args.working_dir) + + if args.filter_by_profile: + filter_by_profile = FilterByProfile( + trestle_root, args.filter_by_profile + ) + + authored_comp: AuthoredComponentDefinition = AuthoredComponentDefinition( + args.working_dir + ) + authored_comp.create_new_default( + args.profile_name, + args.compdef_name, + args.component_title, + args.component_description, + args.component_definition_type, + filter_by_profile, + ) + + transformer: ToRulesYAMLTransformer = ToRulesYAMLTransformer() + + # In this case we only want to do the transformation and generation for this component + # definition, so we skip all other component definitions and components. + model_filter: ModelFilter = ModelFilter( + [], [args.compdef_name, args.component_title, f"{RULE_PREFIX}*"] + ) + + rule_transform_task: RuleTransformTask = RuleTransformTask( + working_dir=args.working_dir, + rules_view_dir=RULES_VIEW_DIR, + rule_transformer=transformer, + model_filter=model_filter, + ) + pre_tasks.append(rule_transform_task) + + regenerate_task: RegenerateTask = RegenerateTask( + authored_object=authored_comp, + markdown_dir=args.markdown_path, + model_filter=model_filter, + ) + pre_tasks.append(regenerate_task) + + super().run_base(args, pre_tasks) + except Exception as e: + exit_code = handle_exception(e) + + sys.exit(exit_code) def main() -> None: diff --git a/trestlebot/entrypoints/entrypoint_base.py b/trestlebot/entrypoints/entrypoint_base.py index 32f13c5b..e6a36c9f 100644 --- a/trestlebot/entrypoints/entrypoint_base.py +++ b/trestlebot/entrypoints/entrypoint_base.py @@ -139,13 +139,15 @@ def setup_common_arguments(self) -> None: ) @staticmethod - def run_base(args: argparse.Namespace, pre_tasks: List[TaskBase]) -> None: - """Reusable logic for all entrypoints.""" + def set_git_provider(args: argparse.Namespace) -> Optional[GitProvider]: + """Get the git provider based on the environment and args.""" git_provider: Optional[GitProvider] = None if args.target_branch: if not args.with_token: - logger.error("with-token value cannot be empty") - sys.exit(const.ERROR_EXIT_CODE) + raise EntrypointInvalidArgException( + "--with-token", + "with-token flag must be set when using target-branch", + ) if is_github_actions(): git_provider = GitHub(access_token=args.with_token.read().strip()) @@ -155,49 +157,46 @@ def run_base(args: argparse.Namespace, pre_tasks: List[TaskBase]) -> None: api_token=args.with_token.read().strip(), server_url=server_api_url ) else: - logger.error( + raise EntrypointInvalidArgException( + "--target-branch", ( "target-branch flag is set with an unset git provider. " "To test locally, set the GITHUB_ACTIONS or GITLAB_CI environment variable." - ) + ), ) - sys.exit(const.ERROR_EXIT_CODE) - - exit_code: int = const.SUCCESS_EXIT_CODE - - # Assume it is a successful run, if the bot - # throws an exception update the exit code accordingly - try: - bot = TrestleBot( - working_dir=args.working_dir, - branch=args.branch, - commit_name=args.committer_name, - commit_email=args.committer_email, - author_name=args.author_name, - author_email=args.author_email, - target_branch=args.target_branch, - ) - commit_sha, pr_number = bot.run( - commit_message=args.commit_message, - pre_tasks=pre_tasks, - patterns=comma_sep_to_list(args.file_patterns), - git_provider=git_provider, - pull_request_title=args.pull_request_title, - check_only=args.check_only, - ) - - # Print the full commit sha - if commit_sha: - print(f"Commit Hash: {commit_sha}") # noqa: T201 - - # Print the pr number - if pr_number: - print(f"Pull Request Number: {pr_number}") # noqa: T201 - - except Exception as e: - exit_code = handle_exception(e) - - sys.exit(exit_code) + return git_provider + + def run_base(self, args: argparse.Namespace, pre_tasks: List[TaskBase]) -> None: + """Reusable logic for all entrypoints.""" + + git_provider: Optional[GitProvider] = self.set_git_provider(args) + + # Configure and run the bot + bot = TrestleBot( + working_dir=args.working_dir, + branch=args.branch, + commit_name=args.committer_name, + commit_email=args.committer_email, + author_name=args.author_name, + author_email=args.author_email, + target_branch=args.target_branch, + ) + commit_sha, pr_number = bot.run( + commit_message=args.commit_message, + pre_tasks=pre_tasks, + patterns=comma_sep_to_list(args.file_patterns), + git_provider=git_provider, + pull_request_title=args.pull_request_title, + check_only=args.check_only, + ) + + # Print the full commit sha + if commit_sha: + print(f"Commit Hash: {commit_sha}") # noqa: T201 + + # Print the pr number + if pr_number: + print(f"Pull Request Number: {pr_number}") # noqa: T201 def comma_sep_to_list(string: str) -> List[str]: @@ -206,10 +205,20 @@ def comma_sep_to_list(string: str) -> List[str]: return list(map(str.strip, string.split(","))) if string else [] +class EntrypointInvalidArgException(Exception): + """Custom exception for handling invalid arguments.""" + + def __init__(self, arg: str, msg: str): + super().__init__(f"Invalid args {arg}: {msg}") + + def handle_exception( exception: Exception, msg: str = "Exception occurred during execution" ) -> int: """Log the exception and return the exit code""" logger.error(msg + f": {exception}", exc_info=True) + if isinstance(exception, EntrypointInvalidArgException): + return const.INVALID_ARGS_EXIT_CODE + return const.ERROR_EXIT_CODE diff --git a/trestlebot/entrypoints/rule_transform.py b/trestlebot/entrypoints/rule_transform.py index 2557d27e..d2f420ac 100644 --- a/trestlebot/entrypoints/rule_transform.py +++ b/trestlebot/entrypoints/rule_transform.py @@ -16,9 +16,15 @@ import argparse import logging +import sys from typing import List -from trestlebot.entrypoints.entrypoint_base import EntrypointBase, comma_sep_to_list +from trestlebot.const import SUCCESS_EXIT_CODE +from trestlebot.entrypoints.entrypoint_base import ( + EntrypointBase, + comma_sep_to_list, + handle_exception, +) from trestlebot.entrypoints.log import set_log_level_from_args from trestlebot.tasks.base_task import ModelFilter, TaskBase from trestlebot.tasks.rule_transform_task import RuleTransformTask @@ -56,28 +62,37 @@ def setup_rules_transformation_arguments(self) -> None: def run(self, args: argparse.Namespace) -> None: """Run the rule transform entrypoint.""" - - set_log_level_from_args(args) - - # Configure the YAML Transformer for the task - validation_handler: ValidationHandler = ValidationHandler(parameter_validation) - transformer: ToRulesYAMLTransformer = ToRulesYAMLTransformer(validation_handler) - - # Allow any model to be skipped from the args, by default include all - model_filter: ModelFilter = ModelFilter( - skip_patterns=comma_sep_to_list(args.skip_items), - include_patterns=["*"], - ) - - rule_transform_task: RuleTransformTask = RuleTransformTask( - working_dir=args.working_dir, - rules_view_dir=args.rules_view_path, - rule_transformer=transformer, - model_filter=model_filter, - ) - pre_tasks: List[TaskBase] = [rule_transform_task] - - super().run_base(args, pre_tasks) + exit_code: int = SUCCESS_EXIT_CODE + try: + set_log_level_from_args(args) + + # Configure the YAML Transformer for the task + validation_handler: ValidationHandler = ValidationHandler( + parameter_validation + ) + transformer: ToRulesYAMLTransformer = ToRulesYAMLTransformer( + validation_handler + ) + + # Allow any model to be skipped from the args, by default include all + model_filter: ModelFilter = ModelFilter( + skip_patterns=comma_sep_to_list(args.skip_items), + include_patterns=["*"], + ) + + rule_transform_task: RuleTransformTask = RuleTransformTask( + working_dir=args.working_dir, + rules_view_dir=args.rules_view_path, + rule_transformer=transformer, + model_filter=model_filter, + ) + pre_tasks: List[TaskBase] = [rule_transform_task] + + super().run_base(args, pre_tasks) + except Exception as e: + exit_code = handle_exception(e) + + sys.exit(exit_code) def main() -> None: diff --git a/trestlebot/tasks/authored/base_authored.py b/trestlebot/tasks/authored/base_authored.py index 51ba220a..fbe3f0bd 100644 --- a/trestlebot/tasks/authored/base_authored.py +++ b/trestlebot/tasks/authored/base_authored.py @@ -16,6 +16,7 @@ """Trestle Bot base authored object""" +import os from abc import ABC, abstractmethod @@ -30,6 +31,8 @@ class AuthoredObjectBase(ABC): def __init__(self, trestle_root: str) -> None: """Initialize task base and store trestle root path""" + if not os.path.exists(trestle_root): + raise AuthoredObjectException(f"Root path {trestle_root} does not exist") self._trestle_root = trestle_root def get_trestle_root(self) -> str: