diff --git a/moto/rds/models.py b/moto/rds/models.py index c4825b6a9292..8aae7e86d4aa 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -1890,6 +1890,10 @@ def copy_db_snapshot( tags: Optional[List[Dict[str, str]]] = None, copy_tags: bool = False, ) -> DBSnapshot: + if source_snapshot_identifier.startswith("arn:aws:rds:"): + source_snapshot_identifier = self.extract_snapshot_name_from_arn( + source_snapshot_identifier + ) if source_snapshot_identifier not in self.database_snapshots: raise DBSnapshotNotFoundError(source_snapshot_identifier) @@ -2010,9 +2014,28 @@ def modify_db_instance( def reboot_db_instance(self, db_instance_identifier: str) -> DBInstance: return self.describe_db_instances(db_instance_identifier)[0] + def extract_snapshot_name_from_arn(self, snapshot_arn: str) -> str: + arn_breakdown = snapshot_arn.split(":") + region_name, account_id, resource_type, snapshot_name = arn_breakdown[3:7] + if resource_type != "snapshot": + raise InvalidParameterValue( + "The parameter SourceDBSnapshotIdentifier is not a valid identifier. " + "Identifiers must begin with a letter; must contain only ASCII " + "letters, digits, and hyphens; and must not end with a hyphen or " + "contain two consecutive hyphens." + ) + if region_name != self.region_name or account_id != self.account_id: + raise NotImplementedError( + "Cross account/region snapshot handling is not yet implemented in moto." + ) + return snapshot_name + def restore_db_instance_from_db_snapshot( self, from_snapshot_id: str, overrides: Dict[str, Any] ) -> DBInstance: + if from_snapshot_id.startswith("arn:aws:rds:"): + from_snapshot_id = self.extract_snapshot_name_from_arn(from_snapshot_id) + snapshot = self.describe_db_snapshots( db_instance_identifier=None, db_snapshot_identifier=from_snapshot_id )[0] diff --git a/tests/test_rds/test_rds.py b/tests/test_rds/test_rds.py index 2c7fc5fd1ee4..36dddcadffc4 100644 --- a/tests/test_rds/test_rds.py +++ b/tests/test_rds/test_rds.py @@ -1141,8 +1141,13 @@ def test_create_db_snapshots_with_tags(): @pytest.mark.parametrize("delete_db_instance", [True, False]) +@pytest.mark.parametrize( + "db_snapshot_identifier", + ("snapshot-1", f"arn:aws:rds:{DEFAULT_REGION}:123456789012:snapshot:snapshot-1"), + ids=("by_name", "by_arn"), +) @mock_aws -def test_copy_db_snapshots(delete_db_instance: bool): +def test_copy_db_snapshots(delete_db_instance: bool, db_snapshot_identifier: str): conn = boto3.client("rds", region_name=DEFAULT_REGION) conn.create_db_instance( @@ -1166,7 +1171,8 @@ def test_copy_db_snapshots(delete_db_instance: bool): conn.delete_db_instance(DBInstanceIdentifier="db-primary-1") target_snapshot = conn.copy_db_snapshot( - SourceDBSnapshotIdentifier="snapshot-1", TargetDBSnapshotIdentifier="snapshot-2" + SourceDBSnapshotIdentifier=db_snapshot_identifier, + TargetDBSnapshotIdentifier="snapshot-2", ).get("DBSnapshot") assert target_snapshot.get("Engine") == "postgres" @@ -1176,6 +1182,21 @@ def test_copy_db_snapshots(delete_db_instance: bool): assert result["TagList"] == [] +@mock_aws +def test_copy_db_snapshot_invalid_arns(): + conn = boto3.client("rds", region_name=DEFAULT_REGION) + + invalid_arn = ( + f"arn:aws:rds:{DEFAULT_REGION}:123456789012:this-is-not-a-snapshot:snapshot-1" + ) + with pytest.raises(ClientError) as ex: + conn.copy_db_snapshot( + SourceDBSnapshotIdentifier=invalid_arn, + TargetDBSnapshotIdentifier="snapshot-2", + ) + assert "is not a valid identifier" in ex.value.response["Error"]["Message"] + + original_snapshot_tags = [{"Key": "original", "Value": "snapshot tags"}] new_snapshot_tags = [{"Key": "new", "Value": "tag"}] @@ -1321,11 +1342,18 @@ def test_delete_db_snapshot(): conn.describe_db_snapshots(DBSnapshotIdentifier="snapshot-1") +@pytest.mark.parametrize( + "db_snapshot_identifier", + ("snapshot-1", f"arn:aws:rds:{DEFAULT_REGION}:123456789012:snapshot:snapshot-1"), + ids=("by_name", "by_arn"), +) @pytest.mark.parametrize( "custom_db_subnet_group", [True, False], ids=("custom_subnet", "default_subnet") ) @mock_aws -def test_restore_db_instance_from_db_snapshot(custom_db_subnet_group: bool): +def test_restore_db_instance_from_db_snapshot( + db_snapshot_identifier: str, custom_db_subnet_group: bool +): conn = boto3.client("rds", region_name=DEFAULT_REGION) conn.create_db_instance( DBInstanceIdentifier="db-primary-1", @@ -1353,9 +1381,12 @@ def test_restore_db_instance_from_db_snapshot(custom_db_subnet_group: bool): ) # restore + new_instance = conn.restore_db_instance_from_db_snapshot( + DBInstanceIdentifier="db-restore-1", DBSnapshotIdentifier=db_snapshot_identifier + )["DBInstance"] kwargs = { "DBInstanceIdentifier": "db-restore-1", - "DBSnapshotIdentifier": "snapshot-1", + "DBSnapshotIdentifier": db_snapshot_identifier, } if custom_db_subnet_group: kwargs["DBSubnetGroupName"] = "custom-subnet-group"