Skip to content

Commit

Permalink
RDS: allow snapshot ARNs for copy and restore (#8372)
Browse files Browse the repository at this point in the history
copy_db_snapshot() and restore_db_instance_from_db_snapshot() should accept
both names and full ARNs as snapshot identifier. For copy_db_snapshot(), only
the SourceDBSnapshotIdentifier can be an ARN, the TargetDBSnapshotIdentifier
has to be a name.
  • Loading branch information
snordhausen authored Dec 11, 2024
1 parent 3866eb0 commit 7583336
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
23 changes: 23 additions & 0 deletions moto/rds/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,10 @@ def copy_db_snapshot(
tags: Optional[List[Dict[str, str]]] = None,
copy_tags: bool = False,
) -> DBSnapshot:
if source_snapshot_identifier.startswith("arn:aws:rds:"):
source_snapshot_identifier = self.extract_snapshot_name_from_arn(
source_snapshot_identifier
)
if source_snapshot_identifier not in self.database_snapshots:
raise DBSnapshotNotFoundError(source_snapshot_identifier)

Expand Down Expand Up @@ -2010,9 +2014,28 @@ def modify_db_instance(
def reboot_db_instance(self, db_instance_identifier: str) -> DBInstance:
return self.describe_db_instances(db_instance_identifier)[0]

def extract_snapshot_name_from_arn(self, snapshot_arn: str) -> str:
arn_breakdown = snapshot_arn.split(":")
region_name, account_id, resource_type, snapshot_name = arn_breakdown[3:7]
if resource_type != "snapshot":
raise InvalidParameterValue(
"The parameter SourceDBSnapshotIdentifier is not a valid identifier. "
"Identifiers must begin with a letter; must contain only ASCII "
"letters, digits, and hyphens; and must not end with a hyphen or "
"contain two consecutive hyphens."
)
if region_name != self.region_name or account_id != self.account_id:
raise NotImplementedError(
"Cross account/region snapshot handling is not yet implemented in moto."
)
return snapshot_name

def restore_db_instance_from_db_snapshot(
self, from_snapshot_id: str, overrides: Dict[str, Any]
) -> DBInstance:
if from_snapshot_id.startswith("arn:aws:rds:"):
from_snapshot_id = self.extract_snapshot_name_from_arn(from_snapshot_id)

snapshot = self.describe_db_snapshots(
db_instance_identifier=None, db_snapshot_identifier=from_snapshot_id
)[0]
Expand Down
39 changes: 35 additions & 4 deletions tests/test_rds/test_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,8 +1141,13 @@ def test_create_db_snapshots_with_tags():


@pytest.mark.parametrize("delete_db_instance", [True, False])
@pytest.mark.parametrize(
"db_snapshot_identifier",
("snapshot-1", f"arn:aws:rds:{DEFAULT_REGION}:123456789012:snapshot:snapshot-1"),
ids=("by_name", "by_arn"),
)
@mock_aws
def test_copy_db_snapshots(delete_db_instance: bool):
def test_copy_db_snapshots(delete_db_instance: bool, db_snapshot_identifier: str):
conn = boto3.client("rds", region_name=DEFAULT_REGION)

conn.create_db_instance(
Expand All @@ -1166,7 +1171,8 @@ def test_copy_db_snapshots(delete_db_instance: bool):
conn.delete_db_instance(DBInstanceIdentifier="db-primary-1")

target_snapshot = conn.copy_db_snapshot(
SourceDBSnapshotIdentifier="snapshot-1", TargetDBSnapshotIdentifier="snapshot-2"
SourceDBSnapshotIdentifier=db_snapshot_identifier,
TargetDBSnapshotIdentifier="snapshot-2",
).get("DBSnapshot")

assert target_snapshot.get("Engine") == "postgres"
Expand All @@ -1176,6 +1182,21 @@ def test_copy_db_snapshots(delete_db_instance: bool):
assert result["TagList"] == []


@mock_aws
def test_copy_db_snapshot_invalid_arns():
conn = boto3.client("rds", region_name=DEFAULT_REGION)

invalid_arn = (
f"arn:aws:rds:{DEFAULT_REGION}:123456789012:this-is-not-a-snapshot:snapshot-1"
)
with pytest.raises(ClientError) as ex:
conn.copy_db_snapshot(
SourceDBSnapshotIdentifier=invalid_arn,
TargetDBSnapshotIdentifier="snapshot-2",
)
assert "is not a valid identifier" in ex.value.response["Error"]["Message"]


original_snapshot_tags = [{"Key": "original", "Value": "snapshot tags"}]
new_snapshot_tags = [{"Key": "new", "Value": "tag"}]

Expand Down Expand Up @@ -1321,11 +1342,18 @@ def test_delete_db_snapshot():
conn.describe_db_snapshots(DBSnapshotIdentifier="snapshot-1")


@pytest.mark.parametrize(
"db_snapshot_identifier",
("snapshot-1", f"arn:aws:rds:{DEFAULT_REGION}:123456789012:snapshot:snapshot-1"),
ids=("by_name", "by_arn"),
)
@pytest.mark.parametrize(
"custom_db_subnet_group", [True, False], ids=("custom_subnet", "default_subnet")
)
@mock_aws
def test_restore_db_instance_from_db_snapshot(custom_db_subnet_group: bool):
def test_restore_db_instance_from_db_snapshot(
db_snapshot_identifier: str, custom_db_subnet_group: bool
):
conn = boto3.client("rds", region_name=DEFAULT_REGION)
conn.create_db_instance(
DBInstanceIdentifier="db-primary-1",
Expand Down Expand Up @@ -1353,9 +1381,12 @@ def test_restore_db_instance_from_db_snapshot(custom_db_subnet_group: bool):
)

# restore
new_instance = conn.restore_db_instance_from_db_snapshot(
DBInstanceIdentifier="db-restore-1", DBSnapshotIdentifier=db_snapshot_identifier
)["DBInstance"]
kwargs = {
"DBInstanceIdentifier": "db-restore-1",
"DBSnapshotIdentifier": "snapshot-1",
"DBSnapshotIdentifier": db_snapshot_identifier,
}
if custom_db_subnet_group:
kwargs["DBSubnetGroupName"] = "custom-subnet-group"
Expand Down

0 comments on commit 7583336

Please sign in to comment.