diff --git a/src/rpdk/core/cli.py b/src/rpdk/core/cli.py index 8a4d677c..db65029d 100644 --- a/src/rpdk/core/cli.py +++ b/src/rpdk/core/cli.py @@ -12,6 +12,7 @@ from .build_image import setup_subparser as build_image_setup_subparser from .data_loaders import resource_yaml from .exceptions import DownstreamError, SysExitRecommendedError +from .extensions import setup_subparsers as extensions_setup_subparser from .generate import setup_subparser as generate_setup_subparser from .init import setup_subparser as init_setup_subparser from .invoke import setup_subparser as invoke_setup_subparser @@ -88,6 +89,7 @@ def no_command(args): invoke_setup_subparser(subparsers, parents) unittest_patch_setup_subparser(subparsers, parents) build_image_setup_subparser(subparsers, parents) + extensions_setup_subparser(subparsers, parents) args = parser.parse_args(args=args_in) setup_logging(args.verbose) diff --git a/src/rpdk/core/extensions.py b/src/rpdk/core/extensions.py new file mode 100644 index 00000000..e6d5b061 --- /dev/null +++ b/src/rpdk/core/extensions.py @@ -0,0 +1,18 @@ +from .plugin_registry import get_extensions + + +def _check_command_name_collision(subparsers, command_name): + if command_name in subparsers.choices: + raise RuntimeError( + f'"{command_name}" is already registered as an extension. Please use a different name.' + ) + + +def setup_subparsers(subparsers, parents): + extensions = get_extensions() + + for extension_cls in extensions.values(): + extension = extension_cls()() + _check_command_name_collision(subparsers, extension.command_name) + parser = subparsers.add_parser(extension.command_name, parents=parents) + extension.setup_parser(parser) diff --git a/src/rpdk/core/plugin_base.py b/src/rpdk/core/plugin_base.py index 25683ec5..d90ba708 100644 --- a/src/rpdk/core/plugin_base.py +++ b/src/rpdk/core/plugin_base.py @@ -55,3 +55,19 @@ def generate(self, project): @abstractmethod def package(self, project, zip_file): pass + + +class ExtensionPlugin(ABC): + COMMAND_NAME = None + + @property + def command_name(self): + if not self.COMMAND_NAME: + raise RuntimeError( + "Set COMMAND_NAME to the command you want to extend cfn with: `cfn COMMAND_NAME`." + ) + return self.COMMAND_NAME + + @abstractmethod + def setup_parser(self, parser): + pass diff --git a/src/rpdk/core/plugin_registry.py b/src/rpdk/core/plugin_registry.py index c4b05fb8..a5a4a238 100644 --- a/src/rpdk/core/plugin_registry.py +++ b/src/rpdk/core/plugin_registry.py @@ -23,5 +23,14 @@ def get_parsers(): return parsers +def get_extensions(): + extensions = { + entry_point.name: entry_point.load + for entry_point in pkg_resources.iter_entry_points("rpdk.v1.extensions") + } + + return extensions + + def load_plugin(language): return PLUGIN_REGISTRY[language]()() diff --git a/tests/test_cli.py b/tests/test_cli.py index f3cff6de..2d3153f0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -77,6 +77,15 @@ def test_main_no_args_prints_help(capsys): assert "--help" in out +def test_main_setup_extensions(): + with patch( + "rpdk.core.cli.extensions_setup_subparser" + ) as extensions_setup_subparser: + main(args_in=[]) + + extensions_setup_subparser.assert_called_once() + + def test_main_version_arg_prints_version(capsys): main(args_in=["--version"]) out, err = capsys.readouterr() diff --git a/tests/test_extensions.py b/tests/test_extensions.py new file mode 100644 index 00000000..c9443e77 --- /dev/null +++ b/tests/test_extensions.py @@ -0,0 +1,60 @@ +import argparse +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from rpdk.core.extensions import setup_subparsers + + +class ExtensionTest(TestCase): + def test_setup_subparsers(self): # pylint: disable=no-self-use + expeted_command_name = "expected-command-name" + + mock_extension = MagicMock() + mock_extension.command_name = expeted_command_name + + mock_extension_entry_point = MagicMock() + mock_extension_entry_point.return_value.return_value = mock_extension + + mock_extension_entry_points = {"key": mock_extension_entry_point} + + subparsers, parents, parser = MagicMock(), MagicMock(), MagicMock() + subparsers.add_parser.return_value = parser + + with patch("rpdk.core.extensions.get_extensions") as mock_get_extensions: + mock_get_extensions.return_value = mock_extension_entry_points + setup_subparsers(subparsers, parents) + + mock_extension.setup_parser.assert_called_once_with(parser) + subparsers.add_parser.assert_called_with(expeted_command_name, parents=parents) + + def test_setup_subparsers_should_raise_error_when_collision_occur(self): + command_name = "command-name" + + mock_extension_1, mock_extension_2 = MagicMock(), MagicMock() + mock_extension_1.command_name = command_name + mock_extension_2.command_name = command_name + + mock_extension_1_entry_point = MagicMock() + mock_extension_1_entry_point.return_value.return_value = mock_extension_1 + + mock_extension_2_entry_point = MagicMock() + mock_extension_2_entry_point.return_value.return_value = mock_extension_2 + + mock_extension_entry_points = { + "key_1": mock_extension_1_entry_point, + "key_2": mock_extension_2_entry_point, + } + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + + with patch( + "rpdk.core.extensions.get_extensions" + ) as mock_get_extensions, self.assertRaises(RuntimeError) as context: + mock_get_extensions.return_value = mock_extension_entry_points + setup_subparsers(subparsers, []) + + assert ( + str(context.exception) + == '"command-name" is already registered as an extension. Please use a different name.' + ) diff --git a/tests/test_plugin_base.py b/tests/test_plugin_base.py index 63ea2900..554b1f14 100644 --- a/tests/test_plugin_base.py +++ b/tests/test_plugin_base.py @@ -5,7 +5,11 @@ import pytest from rpdk.core.filters import FILTER_REGISTRY -from rpdk.core.plugin_base import LanguagePlugin, __name__ as plugin_base_name +from rpdk.core.plugin_base import ( + ExtensionPlugin, + LanguagePlugin, + __name__ as plugin_base_name, +) class TestLanguagePlugin(LanguagePlugin): @@ -22,7 +26,7 @@ def package(self, project, zip_file): @pytest.fixture -def plugin(): +def language_plugin(): return TestLanguagePlugin() @@ -34,20 +38,20 @@ def test_language_plugin_module_not_set(): plugin._module_name # pylint: disable=pointless-statement -def test_language_plugin_init_no_op(plugin): - plugin.init(None) +def test_language_plugin_init_no_op(language_plugin): + language_plugin.init(None) -def test_language_plugin_generate_no_op(plugin): - plugin.generate(None) +def test_language_plugin_generate_no_op(language_plugin): + language_plugin.generate(None) -def test_language_plugin_package_no_op(plugin): - plugin.package(None, None) +def test_language_plugin_package_no_op(language_plugin): + language_plugin.package(None, None) -def test_language_plugin_setup_jinja_env_defaults(plugin): - env = plugin._setup_jinja_env() +def test_language_plugin_setup_jinja_env_defaults(language_plugin): + env = language_plugin._setup_jinja_env() assert env.loader assert env.autoescape @@ -57,10 +61,10 @@ def test_language_plugin_setup_jinja_env_defaults(plugin): assert env.get_template("test.txt") -def test_language_plugin_setup_jinja_env_overrides(plugin): +def test_language_plugin_setup_jinja_env_overrides(language_plugin): loader = object() autoescape = object() - env = plugin._setup_jinja_env(autoescape=autoescape, loader=loader) + env = language_plugin._setup_jinja_env(autoescape=autoescape, loader=loader) assert env.loader is loader assert env.autoescape is autoescape @@ -68,17 +72,45 @@ def test_language_plugin_setup_jinja_env_overrides(plugin): assert name in env.filters -def test_language_plugin_setup_jinja_env_no_spec(plugin): +def test_language_plugin_setup_jinja_env_no_spec(language_plugin): with patch( "rpdk.core.plugin_base.importlib.util.find_spec", return_value=None ) as mock_spec, patch("rpdk.core.plugin_base.PackageLoader") as mock_loader: - env = plugin._setup_jinja_env() + env = language_plugin._setup_jinja_env() - mock_spec.assert_called_once_with(plugin._module_name) - mock_loader.assert_has_calls([call(plugin._module_name), call(plugin_base_name)]) + mock_spec.assert_called_once_with(language_plugin._module_name) + mock_loader.assert_has_calls( + [call(language_plugin._module_name), call(plugin_base_name)] + ) assert env.loader assert env.autoescape for name in FILTER_REGISTRY: assert name in env.filters + + +class TestExtensionPlugin(ExtensionPlugin): + COMMAND_NAME = "test-extension" + + def setup_parser(self, parser): + super().setup_parser(parser) + + +@pytest.fixture +def extension_plugin(): + return TestExtensionPlugin() + + +def test_extension_plugin_command_name(extension_plugin): + assert extension_plugin.command_name == "test-extension" + + +def test_extension_plugin_command_name_error(extension_plugin): + extension_plugin.COMMAND_NAME = None + with pytest.raises(RuntimeError): + extension_plugin.command_name # pylint: disable=pointless-statement + + +def test_extension_plugin_setup_parser_no_op(extension_plugin): + extension_plugin.setup_parser(None) diff --git a/tests/test_plugin_registry.py b/tests/test_plugin_registry.py index c6e5abf3..f3145778 100644 --- a/tests/test_plugin_registry.py +++ b/tests/test_plugin_registry.py @@ -1,6 +1,6 @@ from unittest.mock import Mock, patch -from rpdk.core.plugin_registry import load_plugin +from rpdk.core.plugin_registry import get_extensions, load_plugin def test_load_plugin(): @@ -11,3 +11,21 @@ def test_load_plugin(): load_plugin("test") plugin.assert_called_once_with() plugin.return_value.assert_called_once_with() + + +def test_get_extensions(): + mock_entrypoint_1 = Mock() + mock_entrypoint_2 = Mock() + + patch_iter_entry_points = patch( + "rpdk.core.plugin_registry.pkg_resources.iter_entry_points" + ) + with patch_iter_entry_points as mock_iter_entry_points: + mock_iter_entry_points.return_value = [mock_entrypoint_1, mock_entrypoint_2] + + extensions = get_extensions() + + assert extensions == { + mock_entrypoint_1.name: mock_entrypoint_1.load, + mock_entrypoint_2.name: mock_entrypoint_2.load, + }