Skip to content

Commit

Permalink
RDS: Allow more db-instance operations in custom subnets (#8440)
Browse files Browse the repository at this point in the history
  • Loading branch information
snordhausen authored Dec 27, 2024
1 parent a56d269 commit cbc7a51
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 19 deletions.
16 changes: 8 additions & 8 deletions moto/rds/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2005,7 +2005,7 @@ def modify_db_instance(
"You must specify apply immediately when rotating the master user password.",
)
database.update(db_kwargs)
initial_state = copy.deepcopy(database)
initial_state = copy.copy(database)
database.master_user_secret_status = (
"active" # already set the final state in the background
)
Expand Down Expand Up @@ -2065,14 +2065,14 @@ def restore_db_instance_to_point_in_time(
db_instance_identifier=source_db_identifier
)[0]

# remove the db subnet group as it cannot be copied
# and is not used in the restored instance
source_dict = db_instance.__dict__
if "db_subnet_group" in source_dict:
del source_dict["db_subnet_group"]
new_instance_props = {}
for key, value in db_instance.__dict__.items():
# Remove backend / db subnet group as they cannot be copied
# and are not used in the restored instance.
if key in ("backend", "db_subnet_group"):
continue
new_instance_props[key] = copy.deepcopy(value)

new_instance_props = copy.deepcopy(source_dict)
new_instance_props.pop("backend")
if not db_instance.option_group_supplied:
# If the option group is not supplied originally, the 'option_group_name' will receive a default value
# Force this reconstruction, and prevent any validation on the default value
Expand Down
49 changes: 38 additions & 11 deletions tests/test_rds/test_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,18 @@ def test_describe_non_existent_database():
conn.describe_db_instances(DBInstanceIdentifier="not-a-db")


@pytest.mark.parametrize(
"custom_db_subnet_group", [True, False], ids=("custom_subnet", "default_subnet")
)
@mock_aws
def test_modify_db_instance():
def test_modify_db_instance(custom_db_subnet_group: bool):
conn = boto3.client("rds", region_name=DEFAULT_REGION)

if custom_db_subnet_group:
extra_kwargs = {"DBSubnetGroupName": create_db_subnet_group()}
else:
extra_kwargs = {}

conn.create_db_instance(
DBInstanceIdentifier="db-id",
AllocatedStorage=10,
Expand All @@ -563,6 +572,7 @@ def test_modify_db_instance():
MasterUserPassword="hunter2",
Port=1234,
DBSecurityGroups=["my_sg"],
**extra_kwargs,
)
inst = conn.describe_db_instances(DBInstanceIdentifier="db-id")["DBInstances"][0]
assert inst["AllocatedStorage"] == 10
Expand Down Expand Up @@ -1372,13 +1382,7 @@ def test_restore_db_instance_from_db_snapshot(
)

if custom_db_subnet_group:
ec2 = boto3.client("ec2", region_name=DEFAULT_REGION)
subnets = ec2.describe_subnets()["Subnets"]
conn.create_db_subnet_group(
DBSubnetGroupName="custom-subnet-group",
DBSubnetGroupDescription="xxx",
SubnetIds=[subnets[0]["SubnetId"]],
)
db_subnet_group_name = create_db_subnet_group()

# restore
new_instance = conn.restore_db_instance_from_db_snapshot(
Expand All @@ -1389,11 +1393,11 @@ def test_restore_db_instance_from_db_snapshot(
"DBSnapshotIdentifier": db_snapshot_identifier,
}
if custom_db_subnet_group:
kwargs["DBSubnetGroupName"] = "custom-subnet-group"
kwargs["DBSubnetGroupName"] = db_subnet_group_name
new_instance = conn.restore_db_instance_from_db_snapshot(**kwargs)["DBInstance"]
if custom_db_subnet_group:
assert (
new_instance["DBSubnetGroup"]["DBSubnetGroupName"] == "custom-subnet-group"
new_instance["DBSubnetGroup"]["DBSubnetGroupName"] == db_subnet_group_name
)
assert new_instance["DBInstanceIdentifier"] == "db-restore-1"
assert new_instance["DBInstanceClass"] == "db.m1.small"
Expand All @@ -1420,9 +1424,18 @@ def test_restore_db_instance_from_db_snapshot(
)


@pytest.mark.parametrize(
"custom_db_subnet_group", [True, False], ids=("custom_subnet", "default_subnet")
)
@mock_aws
def test_restore_db_instance_to_point_in_time():
def test_restore_db_instance_to_point_in_time(custom_db_subnet_group: bool):
conn = boto3.client("rds", region_name=DEFAULT_REGION)

if custom_db_subnet_group:
extra_kwargs = {"DBSubnetGroupName": create_db_subnet_group()}
else:
extra_kwargs = {}

conn.create_db_instance(
DBInstanceIdentifier="db-primary-1",
AllocatedStorage=10,
Expand All @@ -1432,6 +1445,7 @@ def test_restore_db_instance_to_point_in_time():
MasterUsername="root",
MasterUserPassword="hunter2",
DBSecurityGroups=["my_sg"],
**extra_kwargs,
)
assert len(conn.describe_db_instances()["DBInstances"]) == 1

Expand Down Expand Up @@ -2321,6 +2335,19 @@ def test_create_database_in_subnet_group():
)


def create_db_subnet_group(db_subnet_group_name: str = "custom_db_subnet") -> str:
ec2 = boto3.client("ec2", region_name=DEFAULT_REGION)
first_subnet_id = ec2.describe_subnets()["Subnets"][0]["SubnetId"]

rds = boto3.client("rds", region_name=DEFAULT_REGION)
rds.create_db_subnet_group(
DBSubnetGroupName=db_subnet_group_name,
DBSubnetGroupDescription="xxx",
SubnetIds=[first_subnet_id],
)
return db_subnet_group_name


@mock_aws
def test_describe_database_subnet_group():
vpc_conn = boto3.client("ec2", DEFAULT_REGION)
Expand Down

0 comments on commit cbc7a51

Please sign in to comment.