Skip to content
This repository has been archived by the owner on Oct 4, 2024. It is now read-only.

feat(MWAA): adds a check for DNS resolution in a VPC used by MWAA #213

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MWAA/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ This script requires permission to the following API calls:
- [ec2:DescribeSecurityGroups](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeSecurityGroups.html)
- [ec2:DescribeSubnets](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeSubnets.html)
- [ec2:DescribeVpcEndpoints](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeVpcEndpoints.html)
- [ec2:DescribeVpcAttribute](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeVpcAttribute.html)
- [airflow:GetEnvironment](https://docs.aws.amazon.com/mwaa/latest/userguide/mwaa-actions-resources.html)
- [s3:GetBucketPublicAccessBlock](https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetPublicAccessBlock.html)
- [logs:DescribeLogGroups](https://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/API_DescribeLogGroups.html)
Expand Down
69 changes: 56 additions & 13 deletions MWAA/tests/test_verify_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import pytest

from moto import mock_s3
from moto import mock_s3, mock_s3control, mock_ec2
from verify_env import verify_env

@pytest.fixture
Expand Down Expand Up @@ -41,31 +41,31 @@ def test_validation_region():
https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/
'''
regions = [
'us-east-2',
'us-east-1',
'us-west-2',
'ap-northeast-1',
'ap-northeast-2',
'ap-south-1',
'ap-southeast-1',
'ap-southeast-2',
'ap-northeast-1',
'ca-central-1',
'eu-central-1',
'eu-north-1',
'eu-west-1',
'eu-north-1'
'eu-west-2',
'eu-west-3',
'sa-east-1',
'us-east-1',
'us-east-2',
'us-west-2'
]
for region in regions:
assert verify_env.validation_region(region) == region
unsupport_regions = [
'us-west-1',
'af-south-1',
'ap-east-1',
'ap-south-1',
'ap-northeast-3',
'ap-northeast-2',
'ca-central-1',
'eu-west-2',
'eu-south-1',
'eu-west-3',
'me-sourth-1',
'sa-east-1'
'me-south-1',
]
for unsupport_region in unsupport_regions:
with pytest.raises(argparse.ArgumentTypeError) as excinfo:
Expand Down Expand Up @@ -250,6 +250,7 @@ def init_s3():
permisions
'''
@mock_s3
@mock_s3control
def _init_s3(is_bucket_access_blocked, is_account_access_blocked):
s3_client = boto3.client('s3', region_name=TEST_ACCOUNT_REGION)
s3_client.create_bucket(Bucket=TEST_BUCKET_NAME)
Expand Down Expand Up @@ -290,3 +291,45 @@ def test_s3_public_access_block(init_s3, env_info, capfd, is_bucket_access_block
out, _ = capfd.readouterr()

assert expected.format(bucket_arn=TEST_BUCKET_ARN) in out


# Parametrization table for the VPC DNS resolution check.
# Each entry is (value applied to the VPC's enableDnsSupport attribute,
# substring that must appear in the captured stdout).
VPC_DNS_TEST_CASES = [
    (True, "DNS resolution enabled"),      # attribute switched on
    (False, "DNS resolution is disabled"),  # attribute switched off
]


@pytest.fixture(scope="function")
def init_vpcs():
    """
    Fixture that hands back a factory for creating a VPC under the EC2 mock.

    The factory takes ``dns_support_enabled`` and returns the ID of the VPC
    it created. When the argument is ``None`` the ``enableDnsSupport``
    attribute is not modified at all; otherwise it is set to the given value.
    """
    @mock_ec2
    def _init_vpcs(dns_support_enabled):
        ec2_client = boto3.client("ec2", region_name=TEST_ACCOUNT_REGION)
        created = ec2_client.create_vpc(CidrBlock="10.10.0.0/16")
        new_vpc_id = created.get("Vpc").get("VpcId")

        # None means "skip the modify call" rather than "disable".
        if dns_support_enabled is None:
            return new_vpc_id

        ec2_client.modify_vpc_attribute(
            VpcId=new_vpc_id,
            EnableDnsSupport={"Value": dns_support_enabled},
        )
        return new_vpc_id

    return _init_vpcs


@mock_ec2
@pytest.mark.parametrize("dns_support_enabled, expected_string", VPC_DNS_TEST_CASES)
def test_check_vpc_dns_resolution(init_vpcs, capfd, dns_support_enabled, expected_string):
    """
    Create a VPC with the parametrized enableDnsSupport value, run
    check_vpc_dns_resolution against it and verify the expected message
    fragment shows up on stdout.
    """
    vpc_id = init_vpcs(dns_support_enabled)
    ec2_client = boto3.client("ec2", region_name=TEST_ACCOUNT_REGION)

    verify_env.check_vpc_dns_resolution(vpc_id, ec2_client)

    # Only stdout matters here; stderr is captured but not inspected.
    captured = capfd.readouterr()
    assert expected_string in captured.out
23 changes: 22 additions & 1 deletion MWAA/verify_env/verify_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,25 @@ def check_service_vpc_endpoints(ec2_client, subnets):
print("The route for the subnets do not have a NAT Gateway. However, there are sufficient VPC endpoints")


def check_vpc_dns_resolution(vpc_id, ec2_client):
    """
    Print whether DNS resolution (the ``enableDnsSupport`` VPC attribute) is
    enabled for the VPC that MWAA uses.

    DNS resolution is critical when resolving AWS service names; without it
    the webserver, for example, can't connect to the scheduler.
    Originally encountered in
    https://apache-airflow.slack.com/archives/CCRR5EBA7/p1673956637610969?thread_ts=1673539899.621079&cid=CCRR5EBA7

    vpc_id: ID of the VPC to inspect
    ec2_client: EC2 client exposing describe_vpc_attribute
    """
    print("### Verifying DNS resolution is enabled in VPC ID %s..." % (vpc_id,))
    response = ec2_client.describe_vpc_attribute(
        VpcId=vpc_id,
        Attribute="enableDnsSupport",
    )
    dns_attribute = response.get("EnableDnsSupport")

    # Report the failing case first and bail out; the success message is the fall-through.
    if not dns_attribute.get("Value"):
        print("DNS resolution is disabled in VPC ID %s 🚫\n" % (vpc_id,))
        return

    print("VPC ID %s has DNS resolution enabled ✅\n" % (vpc_id,))


def check_routes(input_env, input_subnets, input_subnet_ids, ec2_client):
'''
method to check and make sure routes have access to the internet if public and subnets are private
Expand Down Expand Up @@ -972,14 +991,16 @@ def get_mwaa_utilized_services(ec2_client, vpc):
ssm = boto3.client('ssm', region_name=REGION)
iam = boto3.client('iam', region_name=REGION)
env, subnets, subnet_ids = prompt_user_and_print_info(ENV_NAME, ec2)
vpc_id = subnets[0]['VpcId']
check_iam_permissions(env, iam)
check_kms_key_policy(env, kms)
log_groups = check_log_groups(env, ENV_NAME, logs, cloudtrail)
check_nacl(subnets, subnet_ids, ec2)
check_routes(env, subnets, subnet_ids, ec2)
check_vpc_dns_resolution(vpc_id=vpc_id, ec2_client=ec2)
check_s3_block_public_access(env, s3, s3control)
check_security_groups(env, ec2)
mwaa_services = get_mwaa_utilized_services(ec2, subnets[0]['VpcId'])
mwaa_services = get_mwaa_utilized_services(ec2, vpc_id)
check_connectivity_to_dep_services(env, subnets, ec2, ssm, mwaa_services)
check_for_failing_logs(log_groups, logs)
except ClientError as client_error:
Expand Down