Skip to content

Commit

Permalink
refactor: streamlines exit exception handling (#103)
Browse files Browse the repository at this point in the history
* feat: adds guard during AuthoredObject creation to verify trestle root exists

Signed-off-by: Jennifer Power <[email protected]>

* refactor: streamlines error handling for input validation

Adds a custom exception and adds a try/except block to each
entrypoint run function to centralize exit error handling to
`handle_exception`.

Signed-off-by: Jennifer Power <[email protected]>

---------

Signed-off-by: Jennifer Power <[email protected]>
  • Loading branch information
jpower432 authored Dec 13, 2023
1 parent f6f7035 commit 4ca6958
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 165 deletions.
35 changes: 26 additions & 9 deletions tests/trestlebot/entrypoints/test_autosync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.

"""Test for CLI"""
"""Test for Autosync CLI"""

import logging
from typing import Any, Dict
Expand All @@ -34,7 +34,7 @@ def valid_args_dict() -> Dict[str, str]:
"oscal-model": "profile",
"committer-name": "test",
"committer-email": "[email protected]",
"working-dir": "tmp",
"working-dir": ".",
"file-patterns": ".",
}

Expand All @@ -54,12 +54,13 @@ def test_no_ssp_index(valid_args_dict: Dict[str, str], caplog: Any) -> None:
args_dict["oscal-model"] = "ssp"
args_dict["ssp-index-path"] = ""
with patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]):
with pytest.raises(SystemExit):
with pytest.raises(SystemExit, match="2"):
cli_main()

assert any(
record.levelno == logging.ERROR
and record.message == "Must set ssp_index_path when using SSP as oscal model."
and "Invalid args --ssp-index-path: Must set ssp index path when using SSP as "
"oscal model." in record.message
for record in caplog.records
)

Expand All @@ -69,12 +70,27 @@ def test_no_markdown_path(valid_args_dict: Dict[str, str], caplog: Any) -> None:
args_dict = valid_args_dict
args_dict["markdown-path"] = ""
with patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]):
with pytest.raises(SystemExit):
with pytest.raises(SystemExit, match="2"):
cli_main()

assert any(
record.levelno == logging.ERROR
and record.message == "Must set markdown path with oscal model."
and "Invalid args --markdown-path: Markdown path must be set." in record.message
for record in caplog.records
)


def test_non_existent_working_dir(valid_args_dict: Dict[str, str], caplog: Any) -> None:
"""Test with a non-existent working directory"""
args_dict = valid_args_dict
args_dict["working-dir"] = "tmp"
with patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]):
with pytest.raises(SystemExit, match="1"):
cli_main()

assert any(
record.levelno == logging.ERROR
and "Root path tmp does not exist" in record.message
for record in caplog.records
)

Expand All @@ -91,13 +107,14 @@ def test_with_target_branch(valid_args_dict: Dict[str, str], caplog: Any) -> Non
) as mock_check, patch("sys.argv", ["trestlebot", *args_dict_to_list(args_dict)]):
mock_check.return_value = False

with pytest.raises(SystemExit):
with pytest.raises(SystemExit, match="2"):
cli_main()

assert any(
record.levelno == logging.ERROR
and record.message == "target-branch flag is set with an unset git provider. "
"To test locally, set the GITHUB_ACTIONS or GITLAB_CI environment variable."
and "Invalid args --target-branch: target-branch flag is set with an "
"unset git provider. To test locally, set the GITHUB_ACTIONS or GITLAB_CI environment variable."
in record.message
for record in caplog.records
)

Expand Down
103 changes: 58 additions & 45 deletions trestlebot/entrypoints/autosync.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@
import sys
from typing import List

from trestlebot import const
from trestlebot.entrypoints.entrypoint_base import EntrypointBase, comma_sep_to_list
from trestlebot.const import SUCCESS_EXIT_CODE
from trestlebot.entrypoints.entrypoint_base import (
EntrypointBase,
EntrypointInvalidArgException,
comma_sep_to_list,
handle_exception,
)
from trestlebot.entrypoints.log import set_log_level_from_args
from trestlebot.tasks.assemble_task import AssembleTask
from trestlebot.tasks.authored import types
Expand Down Expand Up @@ -95,63 +100,71 @@ def validate_args(self, args: argparse.Namespace) -> None:
"""Validate the arguments for the autosync entrypoint."""
authored_list: List[str] = [model.value for model in types.AuthoredType]
if args.oscal_model not in authored_list:
logger.error(
f"Invalid value {args.oscal_model} for oscal model. "
f"Please use one of {authored_list}"
raise EntrypointInvalidArgException(
"--oscal-model",
f"Value {args.oscal_model} is not valid."
f"Please use one of {authored_list}",
)
sys.exit(const.ERROR_EXIT_CODE)

if not args.markdown_path:
logger.error("Must set markdown path with oscal model.")
sys.exit(const.ERROR_EXIT_CODE)
raise EntrypointInvalidArgException(
"--markdown-path", "Markdown path must be set."
)

if (
args.oscal_model == types.AuthoredType.SSP.value
and args.ssp_index_path == ""
):
logger.error("Must set ssp_index_path when using SSP as oscal model.")
sys.exit(const.ERROR_EXIT_CODE)
raise EntrypointInvalidArgException(
"--ssp-index-path",
"Must set ssp index path when using SSP as oscal model.",
)

def run(self, args: argparse.Namespace) -> None:
"""Run the autosync entrypoint."""

set_log_level_from_args(args)
self.validate_args(args)

pre_tasks: List[TaskBase] = []
# Allow any model to be skipped from the args, by default include all
model_filter: ModelFilter = ModelFilter(
skip_patterns=comma_sep_to_list(args.skip_items),
include_patterns=["*"],
)

authored_object: AuthoredObjectBase = types.get_authored_object(
args.oscal_model, args.working_dir, args.ssp_index_path
)

# Assuming an edit has occurred assemble would be run before regenerate.
# Adding this to the list first
if not args.skip_assemble:
assemble_task = AssembleTask(
authored_object=authored_object,
markdown_dir=args.markdown_path,
model_filter=model_filter,
exit_code: int = SUCCESS_EXIT_CODE
try:
set_log_level_from_args(args)
self.validate_args(args)

pre_tasks: List[TaskBase] = []
# Allow any model to be skipped from the args, by default include all
model_filter: ModelFilter = ModelFilter(
skip_patterns=comma_sep_to_list(args.skip_items),
include_patterns=["*"],
)
pre_tasks.append(assemble_task)
else:
logger.info("Assemble task skipped.")

if not args.skip_regenerate:
regenerate_task = RegenerateTask(
authored_object=authored_object,
markdown_dir=args.markdown_path,
model_filter=model_filter,

authored_object: AuthoredObjectBase = types.get_authored_object(
args.oscal_model, args.working_dir, args.ssp_index_path
)
pre_tasks.append(regenerate_task)
else:
logger.info("Regeneration task skipped.")

super().run_base(args, pre_tasks)
# Assuming an edit has occurred assemble would be run before regenerate.
# Adding this to the list first
if not args.skip_assemble:
assemble_task: AssembleTask = AssembleTask(
authored_object=authored_object,
markdown_dir=args.markdown_path,
model_filter=model_filter,
)
pre_tasks.append(assemble_task)
else:
logger.info("Assemble task skipped.")

if not args.skip_regenerate:
regenerate_task: RegenerateTask = RegenerateTask(
authored_object=authored_object,
markdown_dir=args.markdown_path,
model_filter=model_filter,
)
pre_tasks.append(regenerate_task)
else:
logger.info("Regeneration task skipped.")

super().run_base(args, pre_tasks)
except Exception as e:
exit_code = handle_exception(e)

sys.exit(exit_code)


def main() -> None:
Expand Down
100 changes: 55 additions & 45 deletions trestlebot/entrypoints/create_cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
import argparse
import logging
import pathlib
import sys
from typing import List, Optional

from trestlebot.const import RULE_PREFIX, RULES_VIEW_DIR
from trestlebot.entrypoints.entrypoint_base import EntrypointBase
from trestlebot.const import RULE_PREFIX, RULES_VIEW_DIR, SUCCESS_EXIT_CODE
from trestlebot.entrypoints.entrypoint_base import EntrypointBase, handle_exception
from trestlebot.entrypoints.log import set_log_level_from_args
from trestlebot.tasks.authored.compdef import (
AuthoredComponentDefinition,
Expand Down Expand Up @@ -91,49 +92,58 @@ def setup_create_cd_arguments(self) -> None:

def run(self, args: argparse.Namespace) -> None:
"""Run the entrypoint."""

set_log_level_from_args(args)
pre_tasks: List[TaskBase] = []
filter_by_profile: Optional[FilterByProfile] = None
trestle_root: pathlib.Path = pathlib.Path(args.working_dir)

if args.filter_by_profile:
filter_by_profile = FilterByProfile(trestle_root, args.filter_by_profile)

authored_comp = AuthoredComponentDefinition(args.working_dir)
authored_comp.create_new_default(
args.profile_name,
args.compdef_name,
args.component_title,
args.component_description,
args.component_definition_type,
filter_by_profile,
)

transformer: ToRulesYAMLTransformer = ToRulesYAMLTransformer()

# In this case we only want to do the transformation and generation for this component
# definition, so we skip all other component definitions and components.
model_filter: ModelFilter = ModelFilter(
[], [args.compdef_name, args.component_title, f"{RULE_PREFIX}*"]
)

rule_transform_task: RuleTransformTask = RuleTransformTask(
working_dir=args.working_dir,
rules_view_dir=RULES_VIEW_DIR,
rule_transformer=transformer,
model_filter=model_filter,
)
pre_tasks.append(rule_transform_task)

regenerate_task = RegenerateTask(
authored_object=authored_comp,
markdown_dir=args.markdown_path,
model_filter=model_filter,
)
pre_tasks.append(regenerate_task)

super().run_base(args, pre_tasks)
exit_code: int = SUCCESS_EXIT_CODE
try:
set_log_level_from_args(args)
pre_tasks: List[TaskBase] = []
filter_by_profile: Optional[FilterByProfile] = None
trestle_root: pathlib.Path = pathlib.Path(args.working_dir)

if args.filter_by_profile:
filter_by_profile = FilterByProfile(
trestle_root, args.filter_by_profile
)

authored_comp: AuthoredComponentDefinition = AuthoredComponentDefinition(
args.working_dir
)
authored_comp.create_new_default(
args.profile_name,
args.compdef_name,
args.component_title,
args.component_description,
args.component_definition_type,
filter_by_profile,
)

transformer: ToRulesYAMLTransformer = ToRulesYAMLTransformer()

# In this case we only want to do the transformation and generation for this component
# definition, so we skip all other component definitions and components.
model_filter: ModelFilter = ModelFilter(
[], [args.compdef_name, args.component_title, f"{RULE_PREFIX}*"]
)

rule_transform_task: RuleTransformTask = RuleTransformTask(
working_dir=args.working_dir,
rules_view_dir=RULES_VIEW_DIR,
rule_transformer=transformer,
model_filter=model_filter,
)
pre_tasks.append(rule_transform_task)

regenerate_task: RegenerateTask = RegenerateTask(
authored_object=authored_comp,
markdown_dir=args.markdown_path,
model_filter=model_filter,
)
pre_tasks.append(regenerate_task)

super().run_base(args, pre_tasks)
except Exception as e:
exit_code = handle_exception(e)

sys.exit(exit_code)


def main() -> None:
Expand Down
Loading

0 comments on commit 4ca6958

Please sign in to comment.