Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

validate download source on deploy agent #1366

Merged
merged 2 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions deploy-agent/deployd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand All @@ -26,5 +26,5 @@
# 0: puppet applied successfully with no changes
# 2: puppet applied successfully with changes
PUPPET_SUCCESS_EXIT_CODES = [0, 2]
__version__ = '1.2.54'

__version__ = '1.2.55'
6 changes: 6 additions & 0 deletions deploy-agent/deployd/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,9 @@ def get_stage_type_key(self):

def get_facter_account_id_key(self):
    """Return the configured ``account_id_key`` setting.

    Falls back to ``'ec2_metadata.identity-credentials.ec2.info'`` when
    ``get_var`` has no value for the setting.
    """
    fallback_key = 'ec2_metadata.identity-credentials.ec2.info'
    return self.get_var('account_id_key', fallback_key)

def get_http_download_allow_list(self):
    """Return the ``http_download_allow_list`` setting.

    An empty list is supplied to ``get_var`` as the default, so an
    unconfigured allow list comes back as ``[]``.
    """
    no_restriction = []
    return self.get_var('http_download_allow_list', no_restriction)

def get_s3_download_allow_list(self):
    """Return the ``s3_download_allow_list`` setting.

    An empty list is supplied to ``get_var`` as the default, so an
    unconfigured allow list comes back as ``[]``.
    """
    no_restriction = []
    return self.get_var('s3_download_allow_list', no_restriction)
13 changes: 9 additions & 4 deletions deploy-agent/deployd/download/download_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand All @@ -18,8 +18,9 @@

log = logging.getLogger(__name__)

DOWNLOAD_VALIDATE_METRICS = 'deployd.stats.download.validate'

class DownloadHelper(object):
class DownloadHelper(metaclass=abc.ABCMeta):

def __init__(self, url):
self._url = url
Expand All @@ -40,6 +41,10 @@ def md5_file(file_path):
md5.update(chunk)
return md5.hexdigest()

@abc.abstractmethod
def download(self, local_full_fn):
    """Download the source held in ``self._url`` to ``local_full_fn``.

    Abstract: every concrete helper must override this. Only
    ``abc.abstractmethod`` is applied; stacking the deprecated
    ``abc.abstractproperty`` on top would wrap the function in a property
    object, making it non-callable on the base class.

    Args:
        local_full_fn: destination path of the downloaded file.

    Returns:
        Nothing in this base implementation. The HTTP and S3 overrides
        visible in this change return a ``Status`` value.
    """
    pass

@abc.abstractmethod
def validate_source(self):
    """Decide whether this helper's source may be downloaded.

    Abstract: concrete helpers override this and return a boolean. The base
    implementation does nothing and returns ``None``.
    """
    return None
30 changes: 28 additions & 2 deletions deploy-agent/deployd/download/http_download_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,25 @@
# limitations under the License.

from deployd.common.caller import Caller
from .downloader import Status
from deployd.download.download_helper import DownloadHelper
from deployd.common.status_code import Status
from deployd.common.config import Config
from deployd.download.download_helper import DownloadHelper, DOWNLOAD_VALIDATE_METRICS
from deployd.common.stats import create_sc_increment
import os
import requests
import logging
from urllib.parse import ParseResult, urlparse
requests.packages.urllib3.disable_warnings()

log = logging.getLogger(__name__)


class HTTPDownloadHelper(DownloadHelper):

def __init__(self, url=None, config=None):
    """Create an HTTP download helper for ``url``.

    Args:
        url: source URL, handed to the base-class initializer.
        config: settings object to consult; when falsy, a new ``Config()``
            is constructed instead.
    """
    super().__init__(url)
    self._config = config or Config()

def _download_files(self, local_full_fn):
download_cmd = ['curl', '-o', local_full_fn, '-fksS', self._url]
log.info('Running command: {}'.format(' '.join(download_cmd)))
Expand All @@ -41,6 +48,9 @@ def _download_files(self, local_full_fn):
def download(self, local_full_fn):
log.info("Start to download from url {} to {}".format(
self._url, local_full_fn))
if not self.validate_source():
log.error(f'Invalid url: {self._url}. Skip downloading.')
return Status.FAILED

status_code = self._download_files(local_full_fn)
if status_code != Status.SUCCEEDED:
Expand All @@ -65,3 +75,19 @@ def download(self, local_full_fn):
except requests.ConnectionError:
log.error('Could not connect to: {}'.format(self._url))
return Status.FAILED

def validate_source(self):
    """Check whether ``self._url`` is an acceptable HTTP download source.

    Calls ``create_sc_increment`` with ``DOWNLOAD_VALIDATE_METRICS`` once per
    invocation, then accepts the URL only when all of these hold:

    * the scheme is ``https``;
    * the URL has a host component;
    * the configured HTTP allow list is empty, or contains the host.

    The allow list is compared against the parsed *hostname* (lower-cased,
    without port or ``user:password@`` prefix), so ``https://Host.com:443/x``
    matches an entry ``host.com``. An exact match on the raw ``netloc`` is
    still honoured so existing entries written as ``host:port`` keep working.

    Returns:
        bool: True when the download may proceed, False otherwise.
    """
    tags = {'type': 'http', 'url': self._url}
    create_sc_increment(DOWNLOAD_VALIDATE_METRICS, tags=tags)

    parsed_url: ParseResult = urlparse(self._url)
    if parsed_url.scheme != 'https':
        return False

    host = parsed_url.hostname
    if not host:
        return False

    allow_list = self._config.get_http_download_allow_list()
    if not allow_list:
        # No allow list configured: every https host is accepted.
        return True

    # NOTE(review): if the config layer hands back a raw string (e.g. a
    # comma-separated value from an ini file -- confirm against
    # Config.get_var), ``in`` would do substring matching; split it so
    # entries are always compared whole.
    if isinstance(allow_list, str):
        allow_list = [item.strip() for item in allow_list.split(',') if item.strip()]

    allowed = {str(entry).lower() for entry in allow_list}
    return host in allowed or parsed_url.netloc.lower() in allowed
7 changes: 5 additions & 2 deletions deploy-agent/deployd/download/local_download_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down Expand Up @@ -45,3 +45,6 @@ def download(self, local_full_fn):
if error != Status.SUCCEEDED:
log.error('Failed to download the local tar ball for {}'.format(local_full_fn))
return error

def validate_source(self):
    """Report the source as valid.

    This helper applies no restriction to its source, so the check always
    succeeds.
    """
    is_valid = True
    return is_valid
41 changes: 26 additions & 15 deletions deploy-agent/deployd/download/s3_download_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand All @@ -19,8 +19,8 @@
from boto.s3.connection import S3Connection
from deployd.common.config import Config
from deployd.common.status_code import Status
from deployd.download.download_helper import DownloadHelper

from deployd.download.download_helper import DownloadHelper, DOWNLOAD_VALIDATE_METRICS
from deployd.common.stats import create_sc_increment

log = logging.getLogger(__name__)

Expand All @@ -30,35 +30,39 @@ class S3DownloadHelper(DownloadHelper):
def __init__(self, local_full_fn, aws_connection=None, url=None):
    """Create an S3 download helper and parse its bucket and key.

    Args:
        local_full_fn: value handed to the base-class initializer, which
            stores it as ``self._url``.
        aws_connection: connection object to use. When falsy, an
            ``S3Connection`` is built from the access key and secret
            returned by ``Config``.
        url: optional ``s3://bucket/key`` URL that overrides ``self._url``.

    Raises:
        AttributeError: when the effective URL does not match the
            ``s3://<bucket>/<key>`` pattern (same exception type the
            unguarded ``match.group`` call produced, with a readable
            message).
    """
    super(S3DownloadHelper, self).__init__(local_full_fn)
    # Raw string: "\-" and "\." are invalid escapes in a normal literal.
    # The pattern value itself is unchanged.
    self._s3_matcher = r"^s3://(?P<BUCKET>[a-zA-Z0-9\-_]+)/(?P<KEY>[a-zA-Z0-9\-_/\.]+)/?"
    self._config = Config()
    if aws_connection:
        self._aws_connection = aws_connection
    else:
        aws_access_key_id = self._config.get_aws_access_key()
        aws_secret_access_key = self._config.get_aws_access_secret()
        self._aws_connection = S3Connection(aws_access_key_id, aws_secret_access_key, True)

    if url:
        self._url = url

    # Always parse the effective URL so that _bucket_name and _key exist for
    # download() and validate_source(), whether or not `url` was passed.
    s3url_parse = re.match(self._s3_matcher, self._url)
    if s3url_parse is None:
        raise AttributeError(f'Not a valid s3 url: {self._url}')
    self._bucket_name = s3url_parse.group("BUCKET")
    self._key = s3url_parse.group("KEY")

def download(self, local_full_fn):
s3url_parse = re.match(self._s3_matcher, self._url)
bucket_name = s3url_parse.group("BUCKET")
key = s3url_parse.group("KEY")
log.info("Start to download file {} from s3 bucket {} to {}".format(
key, bucket_name, local_full_fn))
log.info(f"Start to download file {self._key} from s3 bucket {self._bucket_name} to {local_full_fn}")
if not self.validate_source():
log.error(f'Invalid url: {self._url}. Skip downloading.')
return Status.FAILED

try:
filekey = self._aws_connection.get_bucket(bucket_name).get_key(key)
filekey = self._aws_connection.get_bucket(self._bucket_name).get_key(self._key)
if filekey is None:
log.error("s3 key {} not found".format(key))
log.error("s3 key {} not found".format(self._key))
return Status.FAILED

filekey.get_contents_to_filename(local_full_fn)
etag = filekey.etag
if "-" not in etag:
if etag.startswith('"') and etag.endswith('"'):
etag = etag[1:-1]

md5 = self.md5_file(local_full_fn)
if md5 != etag:
log.error("MD5 verification failed. tarball is corrupt.")
Expand All @@ -71,3 +75,10 @@ def download(self, local_full_fn):
except Exception:
log.error("Failed to get package from s3: {}".format(traceback.format_exc()))
return Status.FAILED

def validate_source(self):
    """Check whether the parsed S3 bucket is an acceptable download source.

    Calls ``create_sc_increment`` with ``DOWNLOAD_VALIDATE_METRICS`` once per
    invocation. The bucket is accepted when the configured S3 allow list is
    empty, or when ``self._bucket_name`` is one of its entries.

    Returns:
        bool: True when the download may proceed, False otherwise.
    """
    allow_list = self._config.get_s3_download_allow_list()
    tags = {'type': 's3', 'url': self._url, 'bucket' : self._bucket_name}
    create_sc_increment(DOWNLOAD_VALIDATE_METRICS, tags=tags)

    if not allow_list:
        # No allow list configured: every bucket is accepted.
        return True

    # NOTE(review): if the config layer hands back a raw string (e.g. a
    # comma-separated value from an ini file -- confirm against
    # Config.get_var), ``in`` would do substring matching; split it so
    # bucket names are always compared whole.
    if isinstance(allow_list, str):
        allow_list = [name.strip() for name in allow_list.split(',') if name.strip()]

    return self._bucket_name in allow_list
20 changes: 17 additions & 3 deletions deploy-agent/tests/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
load("@rules_python//python:defs.bzl", "py_library")

# Unit tests in unit/deploy/client/test_base_client.py.
# `name` is given exactly once; a repeated keyword argument is an error.
py_test(
    name = "test_base_client",
    srcs = ['unit/deploy/client/test_base_client.py'],
    deps = ["test_lib"],
    python_version = "PY3",
)

# Unit tests in unit/deploy/client/test_serverless_client.py.
# `name` is given exactly once; a repeated keyword argument is an error.
py_test(
    name = "test_serverless_client",
    srcs = ['unit/deploy/client/test_serverless_client.py'],
    deps = ["test_lib"],
    python_version = "PY3",
)

py_test(
name = "test_agent",
name = "test_agent",
srcs = ['unit/deploy/server/test_agent.py'],
deps = ["test_lib"],
python_version = "PY3",
Expand All @@ -28,6 +28,20 @@ py_test(
python_version = "PY3",
)

# Runs unit/deploy/download/test_s3_download_helper.py, which exercises
# S3DownloadHelper.validate_source with and without a bucket allow list.
py_test(
name = "test_s3_download_helper",
srcs = ['unit/deploy/download/test_s3_download_helper.py'],
deps = ["test_lib"],
python_version = "PY3",
)

# Runs unit/deploy/download/test_http_download_helper.py, which exercises
# HTTPDownloadHelper.validate_source (scheme check and host allow list).
py_test(
name = "test_http_download_helper",
srcs = ['unit/deploy/download/test_http_download_helper.py'],
deps = ["test_lib"],
python_version = "PY3",
)

py_library(
name = "test_lib",
srcs = ["__init__.py"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2016 Pinterest, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from deployd.download.http_download_helper import HTTPDownloadHelper
import unittest
from unittest import mock
import logging

logger = logging.getLogger()
logger.level = logging.DEBUG


class TestHttpDownloadHelper(unittest.TestCase):
    """Unit tests for ``HTTPDownloadHelper.validate_source``.

    Each test patches ``Config.get_http_download_allow_list`` so the allow
    list is controlled by the test rather than by configuration.
    """

    def setUp(self):
        self.downloader = HTTPDownloadHelper(url="https://deploy1.com")

    @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list')
    def test_validate_url_with_allow_list(self, mock_get_http_download_allow_list):
        # Host is listed: the source is accepted.
        mock_get_http_download_allow_list.return_value = ['deploy1.com', 'deploy2.com']
        result = self.downloader.validate_source()
        self.assertTrue(result)
        mock_get_http_download_allow_list.assert_called_once()

    @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list')
    def test_validate_url_not_in_allow_list(self, mock_get_http_download_allow_list):
        # Host is absent from a non-empty allow list: the source is rejected.
        mock_get_http_download_allow_list.return_value = ['deploy2.com']
        result = self.downloader.validate_source()

        self.assertFalse(result)
        mock_get_http_download_allow_list.assert_called_once()

    @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list')
    def test_validate_url_with_non_https(self, mock_get_http_download_allow_list):
        # Plain http is rejected before the allow list is even consulted.
        downloader = HTTPDownloadHelper(url="http://deploy1.com")
        mock_get_http_download_allow_list.return_value = ['deploy1', 'deploy2']
        result = downloader.validate_source()

        self.assertFalse(result)
        mock_get_http_download_allow_list.assert_not_called()

    @mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list')
    def test_validate_url_without_allow_list(self, mock_get_http_download_allow_list):
        # No allow list configured: any https host is accepted.
        mock_get_http_download_allow_list.return_value = None
        result = self.downloader.validate_source()

        self.assertTrue(result)
        mock_get_http_download_allow_list.assert_called_once()

if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2016 Pinterest, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from deployd.download.s3_download_helper import S3DownloadHelper
import unittest
from unittest import mock
import logging

logger = logging.getLogger()
logger.level = logging.DEBUG


class TestS3DownloadHelper(unittest.TestCase):
    """Unit tests for ``S3DownloadHelper.validate_source``.

    Each test patches ``Config.get_s3_download_allow_list`` so the allow
    list is controlled by the test rather than by configuration.
    """

    # Stacked mock.patch decorators inject mocks bottom-up: the decorator
    # closest to the function supplies the first mock argument. The parameter
    # order below therefore lists the secret mock first, then the key mock.
    @mock.patch('deployd.download.s3_download_helper.Config.get_aws_access_key')
    @mock.patch('deployd.download.s3_download_helper.Config.get_aws_access_secret')
    def setUp(self, mock_aws_secret, mock_aws_key):
        mock_aws_key.return_value = "test_key"
        mock_aws_secret.return_value = "test_secret"
        self.downloader = S3DownloadHelper(local_full_fn='', url="s3://bucket1/key1")

    @mock.patch('deployd.download.s3_download_helper.Config.get_s3_download_allow_list')
    def test_validate_url_with_allow_list(self, mock_get_s3_download_allow_list):
        # Bucket is listed: the source is accepted.
        mock_get_s3_download_allow_list.return_value = ['bucket1', 'bucket2', 'bucket3']
        result = self.downloader.validate_source()
        self.assertTrue(result)
        mock_get_s3_download_allow_list.assert_called_once()

        # Bucket is absent from a non-empty allow list: the source is rejected.
        mock_get_s3_download_allow_list.return_value = ['bucket3']
        result = self.downloader.validate_source()
        self.assertFalse(result)

    @mock.patch('deployd.download.s3_download_helper.Config.get_s3_download_allow_list')
    def test_validate_url_without_allow_list(self, mock_get_s3_download_allow_list):
        # No allow list configured: any bucket is accepted.
        mock_get_s3_download_allow_list.return_value = None
        result = self.downloader.validate_source()

        self.assertTrue(result)
        mock_get_s3_download_allow_list.assert_called_once()


if __name__ == '__main__':
unittest.main()
Loading