diff --git a/deploy-agent/deployd/BUILD.bazel b/deploy-agent/deployd/BUILD.bazel index 51e968c08b..6974d4eda5 100644 --- a/deploy-agent/deployd/BUILD.bazel +++ b/deploy-agent/deployd/BUILD.bazel @@ -19,6 +19,15 @@ py_binary( ], ) +py_binary( + name = "deploy-downloader", + srcs = ["download/downloader.py"], + main = "download/downloader.py", + deps = [ + ":lib", + ], +) + py_library( name = "third_party", deps = [ diff --git a/deploy-agent/deployd/__init__.py b/deploy-agent/deployd/__init__.py index 0f0ac668bb..cf42a605a5 100644 --- a/deploy-agent/deployd/__init__.py +++ b/deploy-agent/deployd/__init__.py @@ -27,4 +27,4 @@ # 2: puppet applied successfully with changes PUPPET_SUCCESS_EXIT_CODES = [0, 2] -__version__ = '1.2.57' \ No newline at end of file +__version__ = '1.2.58' \ No newline at end of file diff --git a/deploy-agent/deployd/common/config.py b/deploy-agent/deployd/common/config.py index 369acca1ba..ff7e9b25c2 100644 --- a/deploy-agent/deployd/common/config.py +++ b/deploy-agent/deployd/common/config.py @@ -15,6 +15,8 @@ import logging import os +import json + from typing import Any, List, Optional from deployd import __version__ @@ -292,8 +294,19 @@ def get_stage_type_key(self) -> Optional[str]: def get_facter_account_id_key(self) -> str: return self.get_var('account_id_key', 'ec2_metadata.identity-credentials.ec2.info') + def _get_download_allow_list(self, key: str) -> List: + allow_list_str = self.get_var(key, '[]') + allow_list = [] + try: + allow_list = json.loads(allow_list_str) + except json.JSONDecodeError: + log.error(f"Error: The string {allow_list_str} could not be converted to a list.") + return allow_list + + def get_http_download_allow_list(self) -> List: - return self.get_var('http_download_allow_list', []) + return self._get_download_allow_list('http_download_allow_list') + def get_s3_download_allow_list(self) -> List: - return self.get_var('s3_download_allow_list', []) + return self._get_download_allow_list('s3_download_allow_list') diff --git a/deploy-agent/deployd/download/download_helper_factory.py b/deploy-agent/deployd/download/download_helper_factory.py index d483e76b7e..f992f77c8c 100644 --- a/deploy-agent/deployd/download/download_helper_factory.py +++ b/deploy-agent/deployd/download/download_helper_factory.py @@ -3,33 +3,34 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging +from logging import Logger, getLogger from typing import Optional from future.moves.urllib.parse import urlparse from boto.s3.connection import S3Connection +from deployd.download.download_helper import DownloadHelper from deployd.download.s3_download_helper import S3DownloadHelper from deployd.download.http_download_helper import HTTPDownloadHelper from deployd.download.local_download_helper import LocalDownloadHelper -log = logging.getLogger(__name__) +log: Logger = getLogger(__name__) class DownloadHelperFactory(object): @staticmethod - def gen_downloader(url, config) -> Optional[S3DownloadHelper|LocalDownloadHelper|HTTPDownloadHelper]: + def gen_downloader(url, config) -> Optional[DownloadHelper]: url_parse = urlparse(url) if url_parse.scheme == 's3': aws_access_key_id = config.get_aws_access_key() @@ -38,8 +39,8 @@ def gen_downloader(url, config) -> Optional[S3DownloadHelper|LocalDownloadHelper log.error("aws access key id and secret access key not found") return None aws_conn = S3Connection(aws_access_key_id, aws_secret_access_key, True) - return S3DownloadHelper(url, aws_conn) + return S3DownloadHelper(local_full_fn=url, aws_connection=aws_conn, url=None, config=config) elif url_parse.scheme == 'file': - return LocalDownloadHelper(url) + return LocalDownloadHelper(url=url) else: - return HTTPDownloadHelper(url) + return HTTPDownloadHelper(url=url, config=config) diff --git a/deploy-agent/deployd/download/http_download_helper.py b/deploy-agent/deployd/download/http_download_helper.py index 1bdfa3198d..e795457dcc 100644 --- a/deploy-agent/deployd/download/http_download_helper.py +++ b/deploy-agent/deployd/download/http_download_helper.py @@ -90,4 +90,8 @@ def validate_source(self) -> bool: allow_list = self._config.get_http_download_allow_list() - return domain in allow_list if allow_list else True + if not allow_list or domain in allow_list: + return True + else: + log.error(f"{domain} is not in the allow list: {allow_list}.") + return False \ No newline at end of file diff --git a/deploy-agent/deployd/download/s3_download_helper.py b/deploy-agent/deployd/download/s3_download_helper.py index 479c6c09e6..afb2e23204 100644 --- a/deploy-agent/deployd/download/s3_download_helper.py +++ b/deploy-agent/deployd/download/s3_download_helper.py @@ -27,10 +27,10 @@ class S3DownloadHelper(DownloadHelper): - def __init__(self, local_full_fn, aws_connection=None, url=None) -> None: + def __init__(self, local_full_fn, aws_connection=None, url=None, config=None) -> None: super(S3DownloadHelper, self).__init__(local_full_fn) self._s3_matcher = "^s3://(?P[a-zA-Z0-9\-_]+)/(?P[a-zA-Z0-9\-_/\.]+)/?" - self._config = Config() + self._config = config if config else Config() if aws_connection: self._aws_connection = aws_connection else: @@ -81,4 +81,8 @@ def validate_source(self) -> bool: tags = {'type': 's3', 'url': self._url, 'bucket' : self._bucket_name} create_sc_increment(DOWNLOAD_VALIDATE_METRICS, tags=tags) - return self._bucket_name in allow_list if allow_list else True + if not allow_list or self._bucket_name in allow_list: + return True + else: + log.error(f"{self._bucket_name} is not in the allow list: {allow_list}.") + return False \ No newline at end of file diff --git a/deploy-agent/tests/unit/deploy/download/test_http_download_helper.py b/deploy-agent/tests/unit/deploy/download/test_http_download_helper.py index 4ceb35a11f..1aa255acb5 100644 --- a/deploy-agent/tests/unit/deploy/download/test_http_download_helper.py +++ b/deploy-agent/tests/unit/deploy/download/test_http_download_helper.py @@ -13,6 +13,7 @@ # limitations under the License. from deployd.download.http_download_helper import HTTPDownloadHelper +from deployd.common.config import Config import unittest from unittest import mock import logging @@ -24,16 +25,17 @@ class TestHttpDownloadHelper(unittest.TestCase): def setUp(self): - self.downloader = HTTPDownloadHelper(url="https://deploy1.com") + self.config = Config() + self.downloader = HTTPDownloadHelper(url="https://deploy1.com", config=self.config) - @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list') + @mock.patch('deployd.common.config.Config.get_http_download_allow_list') def test_validate_url_with_allow_list(self, mock_get_http_download_allow_list): mock_get_http_download_allow_list.return_value = ['deploy1.com', 'deploy2.com'] result = self.downloader.validate_source() self.assertTrue(result) mock_get_http_download_allow_list.assert_called_once() - @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list') + @mock.patch('deployd.common.config.Config.get_http_download_allow_list') def test_validate_url_with_non_https(self, mock_get_http_download_allow_list): downloader = HTTPDownloadHelper(url="http://deploy1.com") mock_get_http_download_allow_list.return_value = ['deploy1', 'deploy2'] @@ -42,9 +44,9 @@ def test_validate_url_with_non_https(self, mock_get_http_download_allow_list): self.assertFalse(result) mock_get_http_download_allow_list.assert_not_called() - @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list') + @mock.patch('deployd.common.config.Config.get_http_download_allow_list') def test_validate_url_without_allow_list(self, mock_get_http_download_allow_list): - mock_get_http_download_allow_list.return_value = None + mock_get_http_download_allow_list.return_value = [] result = self.downloader.validate_source() self.assertTrue(result) diff --git a/deploy-agent/tests/unit/deploy/download/test_s3_download_helper.py b/deploy-agent/tests/unit/deploy/download/test_s3_download_helper.py index 8103bab6a3..ce06819f72 100644 --- a/deploy-agent/tests/unit/deploy/download/test_s3_download_helper.py +++ b/deploy-agent/tests/unit/deploy/download/test_s3_download_helper.py @@ -16,7 +16,7 @@ import unittest from unittest import mock import logging - +from deployd.common.config import Config logger = logging.getLogger() logger.level = logging.DEBUG @@ -27,9 +27,10 @@ class TestS3DownloadHelper(unittest.TestCase): def setUp(self, mock_aws_key, mock_aws_secret): mock_aws_key.return_value = "test_key" mock_aws_secret.return_value= "test_secret" - self.downloader = S3DownloadHelper(local_full_fn='', url="s3://bucket1/key1") + self.config = Config() + self.downloader = S3DownloadHelper(local_full_fn='', aws_connection=None, url="s3://bucket1/key1", config=self.config) - @mock.patch('deployd.download.s3_download_helper.Config.get_s3_download_allow_list') + @mock.patch('deployd.common.config.Config.get_s3_download_allow_list') def test_validate_url_with_allow_list(self, mock_get_s3_download_allow_list): mock_get_s3_download_allow_list.return_value = ['bucket1', 'bucket2', 'bucket3'] result = self.downloader.validate_source() @@ -40,14 +41,21 @@ def test_validate_url_with_allow_list(self, mock_get_s3_download_allow_list): result = self.downloader.validate_source() self.assertFalse(result) - @mock.patch('deployd.download.s3_download_helper.Config.get_s3_download_allow_list') + @mock.patch('deployd.common.config.Config.get_s3_download_allow_list') def test_validate_url_without_allow_list(self, mock_get_s3_download_allow_list): - mock_get_s3_download_allow_list.return_value = None + mock_get_s3_download_allow_list.return_value = [] result = self.downloader.validate_source() self.assertTrue(result) mock_get_s3_download_allow_list.assert_called_once() + @mock.patch('deployd.common.config.Config.get_s3_download_allow_list') + def test_validate_url_without_allow_list(self, mock_get_s3_download_allow_list): + mock_get_s3_download_allow_list.return_value = [] + result = self.downloader.validate_source() + + self.assertTrue(result) + mock_get_s3_download_allow_list.assert_called_once() if __name__ == '__main__': unittest.main()