Skip to content

Commit

Permalink
Update downloader to read allow lists from configs (#1484)
Browse files Browse the repository at this point in the history
  • Loading branch information
ntascii authored Mar 2, 2024
1 parent 886cd11 commit 0d13631
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 25 deletions.
9 changes: 9 additions & 0 deletions deploy-agent/deployd/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ py_binary(
],
)

# Binary target "deploy-downloader": runs download/downloader.py as its entry
# point and depends on the ":lib" target.
py_binary(
name = "deploy-downloader",
srcs = ["download/downloader.py"],
main = "download/downloader.py",
deps = [
":lib",
],
)

py_library(
name = "third_party",
deps = [
Expand Down
2 changes: 1 addition & 1 deletion deploy-agent/deployd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
# 2: puppet applied successfully with changes
PUPPET_SUCCESS_EXIT_CODES = [0, 2]

__version__ = '1.2.57'
__version__ = '1.2.58'
17 changes: 15 additions & 2 deletions deploy-agent/deployd/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import logging
import os
import json

from typing import Any, List, Optional

from deployd import __version__
Expand Down Expand Up @@ -292,8 +294,19 @@ def get_stage_type_key(self) -> Optional[str]:
# Returns the config variable 'account_id_key', falling back to
# 'ec2_metadata.identity-credentials.ec2.info' when it is not set.
def get_facter_account_id_key(self) -> str:
return self.get_var('account_id_key', 'ec2_metadata.identity-credentials.ec2.info')

def _get_download_allow_list(self, key: str) -> List:
    """Read the config variable ``key`` and decode it as a JSON list.

    The value is fetched with ``self.get_var`` (defaulting to the string
    ``'[]'``) and decoded with ``json.loads``.

    Args:
        key: name of the config variable holding the JSON-encoded list.

    Returns:
        The decoded list. An empty list is returned, and the problem is
        logged rather than raised, when the value is not valid JSON or
        does not decode to a list.
    """
    allow_list_str = self.get_var(key, '[]')
    if isinstance(allow_list_str, list):
        # Already a list (nothing to decode); hand it back unchanged.
        return allow_list_str
    try:
        allow_list = json.loads(allow_list_str)
    except (json.JSONDecodeError, TypeError):
        # TypeError: the value is not a str/bytes/bytearray at all.
        log.error(f"Error: The string {allow_list_str} could not be converted to a list.")
        return []
    if not isinstance(allow_list, list):
        # Valid JSON that is not an array (a bare string, an object, ...)
        # would turn later ``x in allow_list`` checks into substring or
        # key matching, so it is rejected like any other bad value.
        log.error(f"Error: The string {allow_list_str} could not be converted to a list.")
        return []
    return allow_list


def get_http_download_allow_list(self) -> List:
    """Return the allow list configured under ``http_download_allow_list``.

    Delegates to ``_get_download_allow_list`` so the configured value is
    JSON-decoded instead of being handed back as the raw config value.
    """
    return self._get_download_allow_list('http_download_allow_list')


def get_s3_download_allow_list(self) -> List:
    """Return the allow list configured under ``s3_download_allow_list``.

    Delegates to ``_get_download_allow_list`` so the configured value is
    JSON-decoded instead of being handed back as the raw config value.
    """
    return self._get_download_allow_list('s3_download_allow_list')
17 changes: 9 additions & 8 deletions deploy-agent/deployd/download/download_helper_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,34 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from logging import Logger, getLogger
from typing import Optional

from future.moves.urllib.parse import urlparse
from boto.s3.connection import S3Connection

from deployd.download.download_helper import DownloadHelper
from deployd.download.s3_download_helper import S3DownloadHelper
from deployd.download.http_download_helper import HTTPDownloadHelper
from deployd.download.local_download_helper import LocalDownloadHelper


log = logging.getLogger(__name__)
log: Logger = getLogger(__name__)


class DownloadHelperFactory(object):

@staticmethod
def gen_downloader(url, config) -> Optional[S3DownloadHelper|LocalDownloadHelper|HTTPDownloadHelper]:
def gen_downloader(url, config) -> Optional[DownloadHelper]:
url_parse = urlparse(url)
if url_parse.scheme == 's3':
aws_access_key_id = config.get_aws_access_key()
Expand All @@ -38,8 +39,8 @@ def gen_downloader(url, config) -> Optional[S3DownloadHelper|LocalDownloadHelper
log.error("aws access key id and secret access key not found")
return None
aws_conn = S3Connection(aws_access_key_id, aws_secret_access_key, True)
return S3DownloadHelper(url, aws_conn)
return S3DownloadHelper(local_full_fn=url, aws_connection=aws_conn, url=None, config=config)
elif url_parse.scheme == 'file':
return LocalDownloadHelper(url)
return LocalDownloadHelper(url=url)
else:
return HTTPDownloadHelper(url)
return HTTPDownloadHelper(url=url, config=config)
6 changes: 5 additions & 1 deletion deploy-agent/deployd/download/http_download_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,8 @@ def validate_source(self) -> bool:

allow_list = self._config.get_http_download_allow_list()

return domain in allow_list if allow_list else True
if not allow_list or domain in allow_list:
return True
else:
log.error(f"{domain} is not in the allow list: {allow_list}.")
return False
10 changes: 7 additions & 3 deletions deploy-agent/deployd/download/s3_download_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@

class S3DownloadHelper(DownloadHelper):

def __init__(self, local_full_fn, aws_connection=None, url=None) -> None:
def __init__(self, local_full_fn, aws_connection=None, url=None, config=None) -> None:
super(S3DownloadHelper, self).__init__(local_full_fn)
self._s3_matcher = "^s3://(?P<BUCKET>[a-zA-Z0-9\-_]+)/(?P<KEY>[a-zA-Z0-9\-_/\.]+)/?"
self._config = Config()
self._config = config if config else Config()
if aws_connection:
self._aws_connection = aws_connection
else:
Expand Down Expand Up @@ -81,4 +81,8 @@ def validate_source(self) -> bool:
tags = {'type': 's3', 'url': self._url, 'bucket' : self._bucket_name}
create_sc_increment(DOWNLOAD_VALIDATE_METRICS, tags=tags)

return self._bucket_name in allow_list if allow_list else True
if not allow_list or self._bucket_name in allow_list:
return True
else:
log.error(f"{self._bucket_name} is not in the allow list: {allow_list}.")
return False
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from deployd.download.http_download_helper import HTTPDownloadHelper
from deployd.common.config import Config
import unittest
from unittest import mock
import logging
Expand All @@ -24,16 +25,17 @@
class TestHttpDownloadHelper(unittest.TestCase):

def setUp(self):
self.downloader = HTTPDownloadHelper(url="https://deploy1.com")
self.config = Config()
self.downloader = HTTPDownloadHelper(url="https://deploy1.com", config=self.config)

@mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list')
@mock.patch('deployd.common.config.Config.get_http_download_allow_list')
def test_validate_url_with_allow_list(self, mock_get_http_download_allow_list):
mock_get_http_download_allow_list.return_value = ['deploy1.com', 'deploy2.com']
result = self.downloader.validate_source()
self.assertTrue(result)
mock_get_http_download_allow_list.assert_called_once()

@mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list')
@mock.patch('deployd.common.config.Config.get_http_download_allow_list')
def test_validate_url_with_non_https(self, mock_get_http_download_allow_list):
downloader = HTTPDownloadHelper(url="http://deploy1.com")
mock_get_http_download_allow_list.return_value = ['deploy1', 'deploy2']
Expand All @@ -42,9 +44,9 @@ def test_validate_url_with_non_https(self, mock_get_http_download_allow_list):
self.assertFalse(result)
mock_get_http_download_allow_list.assert_not_called()

@mock.patch('deployd.download.http_download_helper.Config.get_http_download_allow_list')
@mock.patch('deployd.common.config.Config.get_http_download_allow_list')
def test_validate_url_without_allow_list(self, mock_get_http_download_allow_list):
mock_get_http_download_allow_list.return_value = None
mock_get_http_download_allow_list.return_value = []
result = self.downloader.validate_source()

self.assertTrue(result)
Expand Down
18 changes: 13 additions & 5 deletions deploy-agent/tests/unit/deploy/download/test_s3_download_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import unittest
from unittest import mock
import logging

from deployd.common.config import Config
logger = logging.getLogger()
logger.level = logging.DEBUG

Expand All @@ -27,9 +27,10 @@ class TestS3DownloadHelper(unittest.TestCase):
def setUp(self, mock_aws_key, mock_aws_secret):
mock_aws_key.return_value = "test_key"
mock_aws_secret.return_value= "test_secret"
self.downloader = S3DownloadHelper(local_full_fn='', url="s3://bucket1/key1")
self.config = Config()
self.downloader = S3DownloadHelper(local_full_fn='', aws_connection=None, url="s3://bucket1/key1", config=self.config)

@mock.patch('deployd.download.s3_download_helper.Config.get_s3_download_allow_list')
@mock.patch('deployd.common.config.Config.get_s3_download_allow_list')
def test_validate_url_with_allow_list(self, mock_get_s3_download_allow_list):
mock_get_s3_download_allow_list.return_value = ['bucket1', 'bucket2', 'bucket3']
result = self.downloader.validate_source()
Expand All @@ -40,14 +41,21 @@ def test_validate_url_with_allow_list(self, mock_get_s3_download_allow_list):
result = self.downloader.validate_source()
self.assertFalse(result)

@mock.patch('deployd.download.s3_download_helper.Config.get_s3_download_allow_list')
@mock.patch('deployd.common.config.Config.get_s3_download_allow_list')
def test_validate_url_without_allow_list(self, mock_get_s3_download_allow_list):
mock_get_s3_download_allow_list.return_value = None
mock_get_s3_download_allow_list.return_value = []
result = self.downloader.validate_source()

self.assertTrue(result)
mock_get_s3_download_allow_list.assert_called_once()

# NOTE(review): a test with this same name and the same body is defined just
# above. This later definition rebinds the name, so only one of the two runs.
# Consider removing one of them or renaming this one to cover a different case.
@mock.patch('deployd.common.config.Config.get_s3_download_allow_list')
def test_validate_url_without_allow_list(self, mock_get_s3_download_allow_list):
mock_get_s3_download_allow_list.return_value = []
result = self.downloader.validate_source()

self.assertTrue(result)
mock_get_s3_download_allow_list.assert_called_once()

if __name__ == '__main__':
unittest.main()

0 comments on commit 0d13631

Please sign in to comment.