diff --git a/pyproject.toml b/pyproject.toml index acebb6bb..7cde8747 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ repository = 'https://github.com/RedHatProductSecurity/trestle-bot' [tool.poetry.scripts] trestlebot-autosync = "trestlebot.entrypoints.autosync:main" trestlebot-rules-transform = "trestlebot.entrypoints.rule_transform:main" +trestlebot-create-cd = "trestlebot.entrypoints.create_cd:main" [tool.poetry.dependencies] python = '^3.8.1' diff --git a/tests/trestlebot/tasks/authored/test_compdef.py b/tests/trestlebot/tasks/authored/test_compdef.py index 42f82143..e90de24d 100644 --- a/tests/trestlebot/tasks/authored/test_compdef.py +++ b/tests/trestlebot/tasks/authored/test_compdef.py @@ -20,19 +20,21 @@ import re import pytest -from trestle.common.model_utils import ModelUtils -from trestle.core.catalog.catalog_interface import CatalogInterface -from trestle.core.profile_resolver import ProfileResolver -from trestle.oscal import profile as prof +from trestle.common.err import TrestleError +from trestle.oscal.profile import Profile from tests import testutils from trestlebot.const import RULES_VIEW_DIR, YAML_EXTENSION from trestlebot.tasks.authored.base_authored import AuthoredObjectException -from trestlebot.tasks.authored.compdef import AuthoredComponentDefinition +from trestlebot.tasks.authored.compdef import ( + AuthoredComponentDefinition, + FilterByProfile, +) from trestlebot.transformers.yaml_transformer import ToRulesYAMLTransformer test_prof = "simplified_nist_profile" +test_filter_prof = "simplified_filter_profile" test_comp = "test_comp" @@ -89,24 +91,17 @@ def test_create_new_default(tmp_trestle_dir: str) -> None: def test_create_new_default_with_filter(tmp_trestle_dir: str) -> None: """Test creating new default component definition with filter""" + # Prepare the workspace trestle_root = pathlib.Path(tmp_trestle_dir) _ = testutils.setup_for_profile(trestle_root, test_prof, "") + testutils.load_from_json(trestle_root, test_filter_prof, test_filter_prof, Profile) authored_comp = AuthoredComponentDefinition(tmp_trestle_dir) - profile_path = ModelUtils.get_model_path_for_name_and_class( - trestle_root, test_prof, prof.Profile - ) - - catalog = ProfileResolver.get_resolved_profile_catalog( - trestle_root, profile_path=profile_path - ) - - catalog_interface = CatalogInterface(catalog) - catalog_interface.delete_control("ac-5") + filter_by_profile = FilterByProfile(trestle_root, test_filter_prof) authored_comp.create_new_default( - test_prof, test_comp, "test", "My desc", "service", catalog_interface + test_prof, test_comp, "test", "My desc", "service", filter_by_profile ) rules_view_dir = trestle_root / RULES_VIEW_DIR @@ -119,9 +114,9 @@ def test_create_new_default_with_filter(tmp_trestle_dir: str) -> None: assert comp_dir.exists() # Verity that the number of rules YAML files has been reduced - # from 12 to 11. + # from 12 to 7. yaml_files = list(comp_dir.glob(f"*{YAML_EXTENSION}")) - assert len(yaml_files) == 11 + assert len(yaml_files) == 7 def test_create_new_default_no_profile(tmp_trestle_dir: str) -> None: @@ -138,3 +133,13 @@ def test_create_new_default_no_profile(tmp_trestle_dir: str) -> None: authored_comp.create_new_default( "fake", test_comp, "test", "My desc", "service" ) + + +def test_filter_by_profile_with_no_profile(tmp_trestle_dir: str) -> None: + """Test creating a profile filter with a non-existent profile""" + trestle_root = pathlib.Path(tmp_trestle_dir) + + with pytest.raises( + TrestleError, match="Profile fake does not exist in the workspace" + ): + _ = FilterByProfile(trestle_root, "fake") diff --git a/tests/trestlebot/tasks/test_assemble_task.py b/tests/trestlebot/tasks/test_assemble_task.py index 711f9870..5c3bfa80 100644 --- a/tests/trestlebot/tasks/test_assemble_task.py +++ b/tests/trestlebot/tasks/test_assemble_task.py @@ -36,6 +36,7 @@ from trestlebot.tasks.assemble_task import AssembleTask from trestlebot.tasks.authored.base_authored import AuthorObjectBase from trestlebot.tasks.authored.types import AuthoredType +from trestlebot.tasks.base_task import ModelFilter test_prof = "simplified_nist_profile" @@ -95,11 +96,13 @@ def test_assemble_task_with_skip(tmp_trestle_dir: str, skip_list: List[str]) -> mock = Mock(spec=AuthorObjectBase) + filter = ModelFilter(skip_list, ["."]) + assemble_task = AssembleTask( working_dir=tmp_trestle_dir, authored_model=AuthoredType.CATALOG.value, markdown_dir=cat_md_dir, - skip_model_list=skip_list, + filter=filter, ) with patch( diff --git a/tests/trestlebot/tasks/test_base_task.py b/tests/trestlebot/tasks/test_base_task.py new file mode 100644 index 00000000..b2887c3c --- /dev/null +++ b/tests/trestlebot/tasks/test_base_task.py @@ -0,0 +1,47 @@ +#!/usr/bin/python + +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Test workspace filtering logic.""" + +import pathlib +from typing import List + +import pytest + +from trestlebot.tasks.base_task import ModelFilter + + +@pytest.mark.parametrize( + "skip_list, include_list, model_name, expected", + [ + [["simplified_nist_catalog"], [], "simplified_nist_catalog", True], + [[], ["simplified_nist_catalog"], "simplified_nist_catalog", False], + [["simplified*"], ["."], "simplified_nist_catalog", True], + [ + ["simplified_nist_catalog"], + ["simplified*"], + "simplified_nist_profile", + False, + ], + ], +) +def test_is_skipped( + skip_list: List[str], include_list: List[str], model_name: str, expected: str +) -> None: + """Test skip logic.""" + model_path = pathlib.Path(model_name) + model_filter = ModelFilter(skip_list, include_list) + assert model_filter.is_skipped(model_path) == expected diff --git a/tests/trestlebot/tasks/test_regenerate_task.py b/tests/trestlebot/tasks/test_regenerate_task.py index 79ede3be..6000b107 100644 --- a/tests/trestlebot/tasks/test_regenerate_task.py +++ b/tests/trestlebot/tasks/test_regenerate_task.py @@ -28,6 +28,7 @@ from tests import testutils from trestlebot.tasks.authored.base_authored import AuthorObjectBase from trestlebot.tasks.authored.types import AuthoredType +from trestlebot.tasks.base_task import ModelFilter from trestlebot.tasks.regenerate_task import RegenerateTask @@ -86,11 +87,13 @@ def test_regenerate_task_with_skip(tmp_trestle_dir: str, skip_list: List[str]) - mock = Mock(spec=AuthorObjectBase) + filter = ModelFilter(skip_list, ["."]) + regenerate_task = RegenerateTask( working_dir=tmp_trestle_dir, authored_model=AuthoredType.CATALOG.value, markdown_dir=cat_md_dir, - skip_model_list=skip_list, + filter=filter, ) with patch( @@ -119,19 +122,6 @@ def test_catalog_regenerate_task(tmp_trestle_dir: str) -> None: assert os.path.exists(os.path.join(tmp_trestle_dir, md_path)) -def test_catalog_regenerate_task_with_skip(tmp_trestle_dir: str) -> None: - """Test catalog regenerate at the task level""" - trestle_root = pathlib.Path(tmp_trestle_dir) - md_path = os.path.join(cat_md_dir, test_cat) - _ = testutils.setup_for_catalog(trestle_root, test_cat, md_path) - - regenerate_task = RegenerateTask( - tmp_trestle_dir, AuthoredType.CATALOG.value, cat_md_dir, "", [test_cat] - ) - assert regenerate_task.execute() == 0 - assert not os.path.exists(os.path.join(tmp_trestle_dir, md_path)) - - def test_profile_regenerate_task(tmp_trestle_dir: str) -> None: """Test profile regenerate at the task level""" trestle_root = pathlib.Path(tmp_trestle_dir) diff --git a/tests/trestlebot/tasks/test_rule_transform_task.py b/tests/trestlebot/tasks/test_rule_transform_task.py index 3c1e0c9d..dd55f3d1 100644 --- a/tests/trestlebot/tasks/test_rule_transform_task.py +++ b/tests/trestlebot/tasks/test_rule_transform_task.py @@ -14,7 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -"""Test for Trestle Bot rule transform task""" +"""Test for Trestle Bot rule transform task.""" import pathlib @@ -26,7 +26,7 @@ from trestle.tasks.csv_to_oscal_cd import RULE_DESCRIPTION, RULE_ID from tests.testutils import setup_rules_view -from trestlebot.tasks.base_task import TaskException +from trestlebot.tasks.base_task import ModelFilter, TaskException from trestlebot.tasks.rule_transform_task import RuleTransformTask from trestlebot.transformers.yaml_transformer import ToRulesYAMLTransformer @@ -116,8 +116,10 @@ def test_rule_transform_task_with_skip(tmp_trestle_dir: str) -> None: trestle_root = pathlib.Path(tmp_trestle_dir) setup_rules_view(trestle_root, test_comp, test_rules_dir) transformer = ToRulesYAMLTransformer() + + filter = ModelFilter([test_comp], []) rule_transform_task = RuleTransformTask( - tmp_trestle_dir, test_rules_dir, transformer, skip_model_list=[test_comp] + tmp_trestle_dir, test_rules_dir, transformer, filter=filter ) return_code = rule_transform_task.execute() assert return_code == 0 diff --git a/trestlebot/const.py b/trestlebot/const.py index f7ad568a..32f8dbaf 100644 --- a/trestlebot/const.py +++ b/trestlebot/const.py @@ -46,3 +46,4 @@ YAML_EXTENSION = ".yaml" RULES_VIEW_DIR = "rules" +RULE_PREFIX = "rule-" diff --git a/trestlebot/entrypoints/autosync.py b/trestlebot/entrypoints/autosync.py index ee85594d..3f638976 100644 --- a/trestlebot/entrypoints/autosync.py +++ b/trestlebot/entrypoints/autosync.py @@ -32,7 +32,7 @@ from trestlebot.entrypoints.log import set_log_level_from_args from trestlebot.tasks.assemble_task import AssembleTask from trestlebot.tasks.authored import types -from trestlebot.tasks.base_task import TaskBase +from trestlebot.tasks.base_task import ModelFilter, TaskBase from trestlebot.tasks.regenerate_task import RegenerateTask @@ -113,35 +113,43 @@ def run(self, args: argparse.Namespace) -> None: logger.error("Must set markdown path with oscal model.") sys.exit(const.ERROR_EXIT_CODE) - if args.oscal_model == "ssp" and args.ssp_index_path == "": + if ( + args.oscal_model == types.AuthoredType.SSP.value + and args.ssp_index_path == "" + ): logger.error("Must set ssp_index_path when using SSP as oscal model.") sys.exit(const.ERROR_EXIT_CODE) + filter: ModelFilter = ModelFilter( + skip_patterns=comma_sep_to_list(args.skip_items), + include_patterns=["."], + ) + # Assuming an edit has occurred assemble would be run before regenerate. # Adding this to the list first if not args.skip_assemble: assemble_task = AssembleTask( - args.working_dir, - args.oscal_model, - args.markdown_path, - args.ssp_index_path, - comma_sep_to_list(args.skip_items), + working_dir=args.working_dir, + authored_model=args.oscal_model, + markdown_dir=args.markdown_path, + ssp_index_path=args.ssp_index_path, + filter=filter, ) pre_tasks.append(assemble_task) else: - logger.info("Assemble task skipped") + logger.info("Assemble task skipped.") if not args.skip_regenerate: regenerate_task = RegenerateTask( - args.working_dir, - args.oscal_model, - args.markdown_path, - args.ssp_index_path, - comma_sep_to_list(args.skip_items), + working_dir=args.working_dir, + authored_model=args.oscal_model, + markdown_dir=args.markdown_path, + ssp_index_path=args.ssp_index_path, + filter=filter, ) pre_tasks.append(regenerate_task) else: - logger.info("Regeneration task skipped") + logger.info("Regeneration task skipped.") super().run_base(args, pre_tasks) diff --git a/trestlebot/entrypoints/create_cd.py b/trestlebot/entrypoints/create_cd.py new file mode 100644 index 00000000..b46093b1 --- /dev/null +++ b/trestlebot/entrypoints/create_cd.py @@ -0,0 +1,153 @@ +#!/usr/bin/python + +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Entrypoint for component definition bootstrapping. + +This will create a rules-view directory in the working directory, create a component +definition in JSON format with the initially generated rules and initial trestle control markdown. +""" + +import argparse +import logging +import pathlib +from typing import List, Optional + +from trestlebot.const import RULE_PREFIX, RULES_VIEW_DIR +from trestlebot.entrypoints.entrypoint_base import EntrypointBase +from trestlebot.entrypoints.log import set_log_level_from_args +from trestlebot.tasks.authored.compdef import ( + AuthoredComponentDefinition, + FilterByProfile, +) +from trestlebot.tasks.authored.types import AuthoredType +from trestlebot.tasks.base_task import ModelFilter, TaskBase +from trestlebot.tasks.regenerate_task import RegenerateTask +from trestlebot.tasks.rule_transform_task import RuleTransformTask +from trestlebot.transformers.yaml_transformer import ToRulesYAMLTransformer + + +logger = logging.getLogger(__name__) + + +class CreateCDEntrypoint(EntrypointBase): + """Entrypoint for component definition bootstrapping.""" + + def __init__(self, parser: argparse.ArgumentParser) -> None: + """Initialize.""" + super().__init__(parser) + self.setup_create_cd_arguments() + + def setup_create_cd_arguments(self) -> None: + """Setup specific arguments for this entrypoint.""" + self.parser.add_argument( + "--profile-name", + required=True, + help="Name of profile in the trestle workspace to use with the component definition.", + ) + self.parser.add_argument( + "--compdef-name", required=True, help="Name of component definition" + ) + self.parser.add_argument( + "--component-title", required=True, help="Title of initial component" + ) + self.parser.add_argument( + "--component-description", + required=True, + help="Description of initial component", + ) + self.parser.add_argument( + "--markdown-path", + required=True, + type=str, + help="Path to create markdown files in.", + ) + self.parser.add_argument( + "--component-definition-type", + required=False, + type=str, + choices=["service", "validation"], + default="service", + help="Type of component definition", + ) + self.parser.add_argument( + "--filter-by-profile", + required=False, + type=str, + help="Optionally filter the controls in the component definition by a profile.", + ) + + def run(self, args: argparse.Namespace) -> None: + """Run the entrypoint.""" + + set_log_level_from_args(args) + pre_tasks: List[TaskBase] = [] + filter_by_profile: Optional[FilterByProfile] = None + trestle_root: pathlib.Path = pathlib.Path(args.working_dir) + + if args.filter_by_profile: + filter_by_profile = FilterByProfile(trestle_root, args.filter_by_profile) + + authored_comp = AuthoredComponentDefinition(args.working_dir) + authored_comp.create_new_default( + args.profile_name, + args.compdef_name, + args.component_title, + args.component_description, + args.component_definition_type, + filter_by_profile, + ) + + transformer: ToRulesYAMLTransformer = ToRulesYAMLTransformer() + + # In this case we only want to do the transformation and generation for this component + # definition, so we skip all other component definitions and components. + workspace_filter: ModelFilter = ModelFilter( + [], [args.compdef_name, args.component_title, f"{RULE_PREFIX}*"] + ) + + rule_transform_task: RuleTransformTask = RuleTransformTask( + working_dir=args.working_dir, + rules_view_dir=RULES_VIEW_DIR, + rule_transformer=transformer, + filter=workspace_filter, + ) + pre_tasks.append(rule_transform_task) + + regenerate_task = RegenerateTask( + working_dir=args.working_dir, + authored_model=AuthoredType.COMPDEF.value, + markdown_dir=args.markdown_path, + filter=workspace_filter, + ) + pre_tasks.append(regenerate_task) + + super().run_base(args, pre_tasks) + + +def main() -> None: + """Run the CLI.""" + parser = argparse.ArgumentParser( + description="Create new component definition with defaults." + ) + set_default_component_fields = CreateCDEntrypoint(parser=parser) + + args = parser.parse_args() + set_default_component_fields.run(args) + + +if __name__ == "__main__": + main() diff --git a/trestlebot/entrypoints/entrypoint_base.py b/trestlebot/entrypoints/entrypoint_base.py index 771f71da..24434997 100644 --- a/trestlebot/entrypoints/entrypoint_base.py +++ b/trestlebot/entrypoints/entrypoint_base.py @@ -73,8 +73,9 @@ def setup_common_arguments(self) -> None: ) self.parser.add_argument( "--file-patterns", - required=True, + required=False, type=str, + default=".", help="Comma-separated list of file patterns to be used with `git add` in repository updates", ) self.parser.add_argument( diff --git a/trestlebot/entrypoints/rule_transform.py b/trestlebot/entrypoints/rule_transform.py index 089c4d05..15eee2e8 100644 --- a/trestlebot/entrypoints/rule_transform.py +++ b/trestlebot/entrypoints/rule_transform.py @@ -20,7 +20,7 @@ from trestlebot.entrypoints.entrypoint_base import EntrypointBase, comma_sep_to_list from trestlebot.entrypoints.log import set_log_level_from_args -from trestlebot.tasks.base_task import TaskBase +from trestlebot.tasks.base_task import ModelFilter, TaskBase from trestlebot.tasks.rule_transform_task import RuleTransformTask from trestlebot.transformers.validations import ValidationHandler, parameter_validation from trestlebot.transformers.yaml_transformer import ToRulesYAMLTransformer @@ -63,11 +63,16 @@ def run(self, args: argparse.Namespace) -> None: validation_handler: ValidationHandler = ValidationHandler(parameter_validation) transformer: ToRulesYAMLTransformer = ToRulesYAMLTransformer(validation_handler) + filter: ModelFilter = ModelFilter( + skip_patterns=comma_sep_to_list(args.skip_items), + include_patterns=["."], + ) + rule_transform_task: RuleTransformTask = RuleTransformTask( - args.working_dir, - args.rules_view_path, - transformer, - comma_sep_to_list(args.skip_items), + working_dir=args.working_dir, + rules_view_dir=args.rules_view_path, + rule_transformer=transformer, + filter=filter, ) pre_tasks: List[TaskBase] = [rule_transform_task] diff --git a/trestlebot/tasks/assemble_task.py b/trestlebot/tasks/assemble_task.py index 04183f9a..a2f13d10 100644 --- a/trestlebot/tasks/assemble_task.py +++ b/trestlebot/tasks/assemble_task.py @@ -18,7 +18,7 @@ import os import pathlib -from typing import List +from typing import Optional from trestlebot import const from trestlebot.tasks.authored import types @@ -26,7 +26,7 @@ AuthoredObjectException, AuthorObjectBase, ) -from trestlebot.tasks.base_task import TaskBase, TaskException +from trestlebot.tasks.base_task import ModelFilter, TaskBase, TaskException class AssembleTask(TaskBase): @@ -40,7 +40,7 @@ def __init__( authored_model: str, markdown_dir: str, ssp_index_path: str = "", - skip_model_list: List[str] = [], + filter: Optional[ModelFilter] = None, ) -> None: """ Initialize assemble task. @@ -50,13 +50,13 @@ def __init__( authored_model: String representation of model type markdown_dir: Location of directory to write Markdown in ssp_index_path: Path of ssp index JSON in the workspace - skip_model_list: List of model names to be skipped during processing + filter: Optional filter to apply to the task to include or exclude models from processing """ self._authored_model = authored_model self._markdown_dir = markdown_dir self._ssp_index_path = ssp_index_path - super().__init__(working_dir, skip_model_list) + super().__init__(working_dir, filter) def execute(self) -> int: """Execute task""" diff --git a/trestlebot/tasks/authored/compdef.py b/trestlebot/tasks/authored/compdef.py index b4966d54..29bd3af1 100644 --- a/trestlebot/tasks/authored/compdef.py +++ b/trestlebot/tasks/authored/compdef.py @@ -28,7 +28,7 @@ from trestle.core.profile_resolver import ProfileResolver from trestle.core.repository import AgileAuthoring -from trestlebot.const import RULES_VIEW_DIR, YAML_EXTENSION +from trestlebot.const import RULE_PREFIX, RULES_VIEW_DIR, YAML_EXTENSION from trestlebot.tasks.authored.base_authored import ( AuthoredObjectException, AuthorObjectBase, @@ -42,6 +42,30 @@ from trestlebot.transformers.yaml_transformer import FromRulesYAMLTransformer +class FilterByProfile: + """Filter controls by a profile.""" + + def __init__(self, trestle_root: pathlib.Path, profile_name: str) -> None: + """Initialize.""" + filter_profile_path = ModelUtils.get_model_path_for_name_and_class( + trestle_root, profile_name, prof.Profile + ) + + if filter_profile_path is None: + raise TrestleError( + f"Profile {profile_name} does not exist in the workspace" + ) + + catalog = ProfileResolver.get_resolved_profile_catalog( + trestle_root, filter_profile_path + ) + self._control_ids = CatalogInterface(catalog).get_control_ids() + + def __call__(self, control_id: str) -> bool: + """Filter controls by catalog.""" + return control_id in self._control_ids + + class AuthoredComponentDefinition(AuthorObjectBase): """ Class for authoring OSCAL Component Definitions in automation @@ -103,7 +127,7 @@ def create_new_default( comp_title: str, comp_description: str, comp_type: str, - filter_controls: Optional[CatalogInterface] = None, + filter_by_profile: Optional[FilterByProfile] = None, ) -> None: """ Create the new component definition with default info. @@ -114,7 +138,8 @@ def create_new_default( comp_title: Title of the component comp_description: Description of the component comp_type: Type of the component - filter_controls: Optional catalog to filter the controls to include from the profile + filter_by_profile: Optional filter to use for the component definition control + implementation controls Notes: The beginning of the Component Definition workflow is to create a new @@ -128,7 +153,7 @@ def create_new_default( if existing_profile_path is None: raise AuthoredObjectException( - f"Profile {profile_name} does not exist in the workspace" + f"Profile {profile_name} does not exist in the workspace." ) rule_dir: pathlib.Path = trestle_root.joinpath(RULES_VIEW_DIR, compdef_name) @@ -140,28 +165,12 @@ def create_new_default( rules_view_builder = RulesViewBuilder(trestle_root) - filter_func: Optional[Callable[[str], bool]] = None - if filter_controls is not None: - filter_func = FilterByCatalog(filter_controls) - rules_view_builder.add_rules_for_profile( - existing_profile_path, component_info, filter_func + existing_profile_path, component_info, filter_by_profile ) rules_view_builder.write_to_yaml(rule_dir) -class FilterByCatalog: - """Filter controls by catalog.""" - - def __init__(self, catalog: CatalogInterface) -> None: - """Initialize.""" - self._catalog = catalog - - def __call__(self, control_id: str) -> bool: - """Filter controls by catalog.""" - return control_id in self._catalog.get_control_ids() - - class RulesViewBuilder: """Write TrestleRule objects to YAML files in rules view.""" @@ -197,7 +206,7 @@ def add_rules_for_profile( rule = TrestleRule( component=component_info, - name=f"rule-{control_id}", + name=f"{RULE_PREFIX}{control_id}", description=f"Rule for {control_id}", profile=Profile( href=const.TRESTLE_HREF_HEADING diff --git a/trestlebot/tasks/base_task.py b/trestlebot/tasks/base_task.py index 059c969f..6624d292 100644 --- a/trestlebot/tasks/base_task.py +++ b/trestlebot/tasks/base_task.py @@ -14,12 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. -"""Trestle Bot base task for extendable bot pre-tasks""" +"""Trestle Bot base task for extensible bot pre-tasks""" import fnmatch import pathlib from abc import ABC, abstractmethod -from typing import Iterable, List +from typing import Callable, Iterable, List, Optional from trestle.common import const from trestle.common.file_utils import is_hidden @@ -29,22 +29,53 @@ class TaskException(Exception): """An error during task execution""" +class ModelFilter: + """ + Filter models based on include and exclude patterns. + + Args: + skip_patterns: List of glob patterns to exclude from processing. + include_patterns: List of glob patterns to include in processing. + + Note: If a model is in both the include and exclude lists, it will be excluded. + The skip list is applied first. + """ + + def __init__(self, skip_patterns: List[str], include_patterns: List[str]): + self._include_model_list: List[str] = include_patterns + self._skip_model_list: List[str] = [const.TRESTLE_KEEP_FILE] + skip_patterns + + def is_skipped(self, model_path: pathlib.Path) -> bool: + """Check if the model is skipped through include or skip lists.""" + if any( + fnmatch.fnmatch(model_path.name, pattern) + for pattern in self._skip_model_list + ): + return True + elif any( + fnmatch.fnmatch(model_path.name, pattern) + for pattern in self._include_model_list + ): + return False + else: + return True + + class TaskBase(ABC): """ Abstract base class for tasks with a work directory. """ - def __init__(self, working_dir: str, skip_list: List[str]) -> None: + def __init__(self, working_dir: str, filter: Optional[ModelFilter]) -> None: """ Initialize base task. Args: working_dir: Working directory to complete operations in. - skip_list: List of glob patterns to be skipped during processing. + filter: Model filter to use for this task. """ self._working_dir = working_dir - self._skip_model_list = skip_list - self._skip_model_list.append(const.TRESTLE_KEEP_FILE) + self.filter: Optional[ModelFilter] = filter @property def working_dir(self) -> str: @@ -53,22 +84,27 @@ def working_dir(self) -> str: def iterate_models(self, directory_path: pathlib.Path) -> Iterable[pathlib.Path]: """Iterate over the models in the working directory""" - filtered_paths = list( - filter( - lambda p: not self._is_skipped(p.name) - and (not is_hidden(p) or p.is_dir()), - pathlib.Path.iterdir(directory_path), + filtered_paths: Iterable[pathlib.Path] + + if self.filter is not None: + is_skipped: Callable[[pathlib.Path], bool] = self.filter.is_skipped + filtered_paths = list( + filter( + lambda p: not is_skipped(p) and (not is_hidden(p) or p.is_dir()), + pathlib.Path.iterdir(directory_path), + ) + ) + else: + filtered_paths = list( + filter( + lambda p: not is_hidden(p) or p.is_dir(), + pathlib.Path.iterdir(directory_path), + ) ) - ) return filtered_paths.__iter__() - def _is_skipped(self, model_name: str) -> bool: - """Return True if the model is in the skip list""" - return any( - fnmatch.fnmatch(model_name, pattern) for pattern in self._skip_model_list - ) - @abstractmethod def execute(self) -> int: """Execute the task and return the exit code""" + pass diff --git a/trestlebot/tasks/regenerate_task.py b/trestlebot/tasks/regenerate_task.py index 9494b214..76b51402 100644 --- a/trestlebot/tasks/regenerate_task.py +++ b/trestlebot/tasks/regenerate_task.py @@ -18,7 +18,7 @@ import os import pathlib -from typing import List +from typing import Optional from trestlebot import const from trestlebot.tasks.authored import types @@ -26,7 +26,7 @@ AuthoredObjectException, AuthorObjectBase, ) -from trestlebot.tasks.base_task import TaskBase, TaskException +from trestlebot.tasks.base_task import ModelFilter, TaskBase, TaskException class RegenerateTask(TaskBase): @@ -40,7 +40,7 @@ def __init__( authored_model: str, markdown_dir: str, ssp_index_path: str = "", - skip_model_list: List[str] = [], + filter: Optional[ModelFilter] = None, ) -> None: """ Initialize regenerate task. @@ -50,13 +50,13 @@ def __init__( authored_model: String representation of model type markdown_dir: Location of directory to write Markdown in ssp_index_path: Path of ssp index JSON in the workspace - skip_model_list: List of model names to be skipped during processing + filter: Optional filter to apply to the task to include or exclude models from processing. """ self._authored_model = authored_model self._markdown_dir = markdown_dir self._ssp_index_path = ssp_index_path - super().__init__(working_dir, skip_model_list) + super().__init__(working_dir, filter) def execute(self) -> int: """Execute task""" diff --git a/trestlebot/tasks/rule_transform_task.py b/trestlebot/tasks/rule_transform_task.py index 71003efa..0535b3d3 100644 --- a/trestlebot/tasks/rule_transform_task.py +++ b/trestlebot/tasks/rule_transform_task.py @@ -20,14 +20,14 @@ import logging import os import pathlib -from typing import List +from typing import List, Optional import trestle.common.const as trestle_const from trestle.tasks.base_task import TaskOutcome from trestle.tasks.csv_to_oscal_cd import CsvToOscalComponentDefinition import trestlebot.const as const -from trestlebot.tasks.base_task import TaskBase, TaskException +from trestlebot.tasks.base_task import ModelFilter, TaskBase, TaskException from trestlebot.transformers.base_transformer import RulesTransformerException from trestlebot.transformers.csv_transformer import CSVBuilder from trestlebot.transformers.yaml_transformer import ToRulesYAMLTransformer @@ -46,7 +46,7 @@ def __init__( working_dir: str, rules_view_dir: str, rule_transformer: ToRulesYAMLTransformer, - skip_model_list: List[str] = [], + filter: Optional[ModelFilter] = None, ) -> None: """ Initialize transform task. @@ -55,7 +55,7 @@ def __init__( working_dir: Working directory to complete operations in rule_view_dir: Location of directory containing components with to read rules from rule_transformer: Transformer to use for rule transformation to TrestleRule - skip_model_list: List of rule names to be skipped during processing + filter: Optional filter to apply to the task to include or exclude models from processing. Notes: The rule_view_dir is expected to be a directory containing directories of @@ -66,7 +66,7 @@ def __init__( self._rule_view_dir = rules_view_dir self._rule_transformer: ToRulesYAMLTransformer = rule_transformer - super().__init__(working_dir, skip_model_list) + super().__init__(working_dir, filter) def execute(self) -> int: """Execute task"""