From 7286e48df883ea3b41125a2d9dc3baff9d97352e Mon Sep 17 00:00:00 2001 From: ntascii Date: Fri, 20 Oct 2023 07:30:31 +0000 Subject: [PATCH] validate download source --- deploy-agent/deployd/__init__.py | 8 +-- deploy-agent/deployd/common/config.py | 14 +++-- .../deployd/download/download_helper.py | 13 +++-- .../deployd/download/http_download_helper.py | 30 ++++++++++- .../deployd/download/local_download_helper.py | 7 ++- .../deployd/download/s3_download_helper.py | 41 ++++++++------ deploy-agent/tests/BUILD.bazel | 20 +++++-- .../download/test_http_download_helper.py | 54 +++++++++++++++++++ .../download/test_s3_download_helper.py | 53 ++++++++++++++++++ 9 files changed, 206 insertions(+), 34 deletions(-) create mode 100644 deploy-agent/tests/unit/deploy/download/test_http_download_helper.py create mode 100644 deploy-agent/tests/unit/deploy/download/test_s3_download_helper.py diff --git a/deploy-agent/deployd/__init__.py b/deploy-agent/deployd/__init__.py index a3fb64d706..f0f0343a7b 100644 --- a/deploy-agent/deployd/__init__.py +++ b/deploy-agent/deployd/__init__.py @@ -3,9 +3,9 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -25,5 +25,5 @@ # 0: puppet applied successfully with no changes # 2: puppet applied successfully with changes PUPPET_SUCCESS_EXIT_CODES = [0, 2] - -__version__ = '1.2.50' + +__version__ = '1.2.55' diff --git a/deploy-agent/deployd/common/config.py b/deploy-agent/deployd/common/config.py index 6ae58cca42..58caaabe07 100644 --- a/deploy-agent/deployd/common/config.py +++ b/deploy-agent/deployd/common/config.py @@ -4,9 +4,9 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -244,7 +244,7 @@ def get_puppet_exit_code_file_path(self): def get_daemon_sleep_time(self): return self.get_intvar("daemon_sleep_time", 30) - + def get_init_sleep_time(self): return self.get_intvar("init_sleep_time", 50) @@ -267,7 +267,7 @@ def get_facter_name_key(self): def get_facter_group_key(self): return self.get_var('agent_group_key', None) - + def get_verify_https_certificate(self): return self.get_var('verify_https_certificate', 'False') @@ -288,3 +288,9 @@ def get_stage_type_key(self): def get_facter_account_id_key(self): return self.get_var('account_id_key', 'ec2_metadata.identity-credentials.ec2.info') + + def get_http_download_allow_list(self): + return self.get_var('http_download_allow_list', []) + + def get_s3_download_allow_list(self): + return self.get_var('s3_download_allow_list', []) diff --git a/deploy-agent/deployd/download/download_helper.py b/deploy-agent/deployd/download/download_helper.py index 439e05a4da..5d760ab0ed 100644 --- a/deploy-agent/deployd/download/download_helper.py +++ b/deploy-agent/deployd/download/download_helper.py @@ -3,9 +3,9 @@ # Licensed under the Apache 
License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,8 +18,9 @@ log = logging.getLogger(__name__) +DOWNLOAD_VALIDATE_METRICS = 'deployd.stats.download.validate' -class DownloadHelper(object): +class DownloadHelper(metaclass=abc.ABCMeta): def __init__(self, url): self._url = url @@ -40,6 +41,10 @@ def md5_file(file_path): md5.update(chunk) return md5.hexdigest() - @abc.abstractproperty + @abc.abstractmethod def download(self, local_full_fn): pass + + @abc.abstractmethod + def validate_source(self): + pass diff --git a/deploy-agent/deployd/download/http_download_helper.py b/deploy-agent/deployd/download/http_download_helper.py index a3928c65bd..43148099ed 100644 --- a/deploy-agent/deployd/download/http_download_helper.py +++ b/deploy-agent/deployd/download/http_download_helper.py @@ -14,11 +14,14 @@ # limitations under the License. 
from deployd.common.caller import Caller -from .downloader import Status -from deployd.download.download_helper import DownloadHelper +from deployd.common.status_code import Status +from deployd.common.config import Config +from deployd.download.download_helper import DownloadHelper, DOWNLOAD_VALIDATE_METRICS +from deployd.common.stats import create_sc_increment import os import requests import logging +from urllib.parse import ParseResult, urlparse requests.packages.urllib3.disable_warnings() log = logging.getLogger(__name__) @@ -26,6 +29,10 @@ class HTTPDownloadHelper(DownloadHelper): + def __init__(self, url=None, config=None): + super().__init__(url) + self._config = config if config else Config() + def _download_files(self, local_full_fn): download_cmd = ['curl', '-o', local_full_fn, '-fksS', self._url] log.info('Running command: {}'.format(' '.join(download_cmd))) @@ -41,6 +48,9 @@ def _download_files(self, local_full_fn): def download(self, local_full_fn): log.info("Start to download from url {} to {}".format( self._url, local_full_fn)) + if not self.validate_source(): + log.error(f'Invalid url: {self._url}. 
Skip downloading.') + return Status.FAILED status_code = self._download_files(local_full_fn) if status_code != Status.SUCCEEDED: @@ -65,3 +75,19 @@ def download(self, local_full_fn): except requests.ConnectionError: log.error('Could not connect to: {}'.format(self._url)) return Status.FAILED + + def validate_source(self): + tags = {'type': 'http', 'url': self._url} + create_sc_increment(DOWNLOAD_VALIDATE_METRICS, tags=tags) + + parsed_url: ParseResult = urlparse(self._url) + if not parsed_url.scheme == 'https': + return False + + domain: str = parsed_url.netloc + if not domain: + return False + + allow_list = self._config.get_http_download_allow_list() + + return domain in allow_list if allow_list else True diff --git a/deploy-agent/deployd/download/local_download_helper.py b/deploy-agent/deployd/download/local_download_helper.py index 3d457824c6..47d1188d42 100644 --- a/deploy-agent/deployd/download/local_download_helper.py +++ b/deploy-agent/deployd/download/local_download_helper.py @@ -3,9 +3,9 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -45,3 +45,6 @@ def download(self, local_full_fn): if error != Status.SUCCEEDED: log.error('Failed to download the local tar ball for {}'.format(local_full_fn)) return error + + def validate_source(self): + return True \ No newline at end of file diff --git a/deploy-agent/deployd/download/s3_download_helper.py b/deploy-agent/deployd/download/s3_download_helper.py index 26b714982a..d1ff0d8458 100644 --- a/deploy-agent/deployd/download/s3_download_helper.py +++ b/deploy-agent/deployd/download/s3_download_helper.py @@ -3,9 +3,9 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,8 +19,8 @@ from boto.s3.connection import S3Connection from deployd.common.config import Config from deployd.common.status_code import Status -from deployd.download.download_helper import DownloadHelper - +from deployd.download.download_helper import DownloadHelper, DOWNLOAD_VALIDATE_METRICS +from deployd.common.stats import create_sc_increment log = logging.getLogger(__name__) @@ -30,27 +30,31 @@ class S3DownloadHelper(DownloadHelper): def __init__(self, local_full_fn, aws_connection=None, url=None): super(S3DownloadHelper, self).__init__(local_full_fn) self._s3_matcher = "^s3://(?P<BUCKET>[a-zA-Z0-9\-_]+)/(?P<KEY>[a-zA-Z0-9\-_/\.]+)/?" 
+ self._config = Config() if aws_connection: self._aws_connection = aws_connection else: - config = Config() - aws_access_key_id = config.get_aws_access_key() - aws_secret_access_key = config.get_aws_access_secret() + aws_access_key_id = self._config.get_aws_access_key() + aws_secret_access_key = self._config.get_aws_access_secret() self._aws_connection = S3Connection(aws_access_key_id, aws_secret_access_key, True) + if url: self._url = url + s3url_parse = re.match(self._s3_matcher, self._url) + self._bucket_name = s3url_parse.group("BUCKET") + self._key = s3url_parse.group("KEY") + def download(self, local_full_fn): - s3url_parse = re.match(self._s3_matcher, self._url) - bucket_name = s3url_parse.group("BUCKET") - key = s3url_parse.group("KEY") - log.info("Start to download file {} from s3 bucket {} to {}".format( - key, bucket_name, local_full_fn)) + log.info(f"Start to download file {self._key} from s3 bucket {self._bucket_name} to {local_full_fn}") + if not self.validate_source(): + log.error(f'Invalid url: {self._url}. Skip downloading.') + return Status.FAILED try: - filekey = self._aws_connection.get_bucket(bucket_name).get_key(key) + filekey = self._aws_connection.get_bucket(self._bucket_name).get_key(self._key) if filekey is None: - log.error("s3 key {} not found".format(key)) + log.error("s3 key {} not found".format(self._key)) return Status.FAILED filekey.get_contents_to_filename(local_full_fn) @@ -58,7 +62,7 @@ def download(self, local_full_fn): if "-" not in etag: if etag.startswith('"') and etag.endswith('"'): etag = etag[1:-1] - + md5 = self.md5_file(local_full_fn) if md5 != etag: log.error("MD5 verification failed. 
tarball is corrupt.") @@ -71,3 +75,10 @@ def download(self, local_full_fn): except Exception: log.error("Failed to get package from s3: {}".format(traceback.format_exc())) return Status.FAILED + + def validate_source(self): + allow_list = self._config.get_s3_download_allow_list() + tags = {'type': 's3', 'url': self._url, 'bucket' : self._bucket_name} + create_sc_increment(DOWNLOAD_VALIDATE_METRICS, tags=tags) + + return self._bucket_name in allow_list if allow_list else True \ No newline at end of file diff --git a/deploy-agent/tests/BUILD.bazel b/deploy-agent/tests/BUILD.bazel index 6814450914..67988bfa71 100644 --- a/deploy-agent/tests/BUILD.bazel +++ b/deploy-agent/tests/BUILD.bazel @@ -1,21 +1,21 @@ load("@rules_python//python:defs.bzl", "py_library") py_test( - name = "test_base_client", + name = "test_base_client", srcs = ['unit/deploy/client/test_base_client.py'], deps = ["test_lib"], python_version = "PY3", ) py_test( - name = "test_serverless_client", + name = "test_serverless_client", srcs = ['unit/deploy/client/test_serverless_client.py'], deps = ["test_lib"], python_version = "PY3", ) py_test( - name = "test_agent", + name = "test_agent", srcs = ['unit/deploy/server/test_agent.py'], deps = ["test_lib"], python_version = "PY3", @@ -28,6 +28,20 @@ py_test( python_version = "PY3", ) +py_test( + name = "test_s3_download_helper", + srcs = ['unit/deploy/download/test_s3_download_helper.py'], + deps = ["test_lib"], + python_version = "PY3", +) + +py_test( + name = "test_http_download_helper", + srcs = ['unit/deploy/download/test_http_download_helper.py'], + deps = ["test_lib"], + python_version = "PY3", +) + py_library( name = "test_lib", srcs = ["__init__.py"], diff --git a/deploy-agent/tests/unit/deploy/download/test_http_download_helper.py b/deploy-agent/tests/unit/deploy/download/test_http_download_helper.py new file mode 100644 index 0000000000..4ceb35a11f --- /dev/null +++ b/deploy-agent/tests/unit/deploy/download/test_http_download_helper.py @@ -0,0 
+1,54 @@ +# Copyright 2016 Pinterest, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from deployd.download.http_download_helper import HTTPDownloadHelper +import unittest +from unittest import mock +import logging + +logger = logging.getLogger() +logger.level = logging.DEBUG + + +class TestHttpDownloadHelper(unittest.TestCase): + + def setUp(self): + self.downloader = HTTPDownloadHelper(url="https://deploy1.com") + + @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list') + def test_validate_url_with_allow_list(self, mock_get_http_download_allow_list): + mock_get_http_download_allow_list.return_value = ['deploy1.com', 'deploy2.com'] + result = self.downloader.validate_source() + self.assertTrue(result) + mock_get_http_download_allow_list.assert_called_once() + + @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list') + def test_validate_url_with_non_https(self, mock_get_http_download_allow_list): + downloader = HTTPDownloadHelper(url="http://deploy1.com") + mock_get_http_download_allow_list.return_value = ['deploy1', 'deploy2'] + result = downloader.validate_source() + + self.assertFalse(result) + mock_get_http_download_allow_list.assert_not_called() + + @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list') + def test_validate_url_without_allow_list(self, mock_get_http_download_allow_list): + mock_get_http_download_allow_list.return_value = None + result = 
self.downloader.validate_source() + + self.assertTrue(result) + mock_get_http_download_allow_list.assert_called_once() + +if __name__ == '__main__': + unittest.main() diff --git a/deploy-agent/tests/unit/deploy/download/test_s3_download_helper.py b/deploy-agent/tests/unit/deploy/download/test_s3_download_helper.py new file mode 100644 index 0000000000..8103bab6a3 --- /dev/null +++ b/deploy-agent/tests/unit/deploy/download/test_s3_download_helper.py @@ -0,0 +1,53 @@ +# Copyright 2016 Pinterest, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from deployd.download.s3_download_helper import S3DownloadHelper +import unittest +from unittest import mock +import logging + +logger = logging.getLogger() +logger.level = logging.DEBUG + + +class TestS3DownloadHelper(unittest.TestCase): + @mock.patch('deployd.download.s3_download_helper.Config.get_aws_access_key') + @mock.patch('deployd.download.s3_download_helper.Config.get_aws_access_secret') + def setUp(self, mock_aws_key, mock_aws_secret): + mock_aws_key.return_value = "test_key" + mock_aws_secret.return_value= "test_secret" + self.downloader = S3DownloadHelper(local_full_fn='', url="s3://bucket1/key1") + + @mock.patch('deployd.download.s3_download_helper.Config.get_s3_download_allow_list') + def test_validate_url_with_allow_list(self, mock_get_s3_download_allow_list): + mock_get_s3_download_allow_list.return_value = ['bucket1', 'bucket2', 'bucket3'] + result = self.downloader.validate_source() + self.assertTrue(result) + mock_get_s3_download_allow_list.assert_called_once() + + mock_get_s3_download_allow_list.return_value = ['bucket3'] + result = self.downloader.validate_source() + self.assertFalse(result) + + @mock.patch('deployd.download.s3_download_helper.Config.get_s3_download_allow_list') + def test_validate_url_without_allow_list(self, mock_get_s3_download_allow_list): + mock_get_s3_download_allow_list.return_value = None + result = self.downloader.validate_source() + + self.assertTrue(result) + mock_get_s3_download_allow_list.assert_called_once() + + +if __name__ == '__main__': + unittest.main()