From 5f25dccc78eca3e3c4b982c9574ef70385524029 Mon Sep 17 00:00:00 2001 From: Kathia Barahona Date: Mon, 20 Jan 2025 14:26:38 +0100 Subject: [PATCH] Improve replication object clean-up on failure. Added some improvements to error handling for whenever pg_migrate is cleaning up any replication object (publication, subscription, replication slot). Our current implementation does not drop objects whenever a create_* function fails (raises an exception). Current clean-up is dependent on the object name returned by those functions, so if nothing is returned then nothing will be cleaned. The improvement makes the cleanup independent from the return values of such functions, as it now fetches the replication object name and verifies whether it exists, instead of relying on return values. --- aiven_db_migrate/migrate/pgmigrate.py | 200 ++++++++++++++++++-------- aiven_db_migrate/migrate/version.py | 2 +- test/conftest.py | 10 +- test/test_pg_migrate.py | 1 - test/test_pg_replication.py | 23 ++- test/test_table_filtering.py | 4 +- 6 files changed, 162 insertions(+), 78 deletions(-) diff --git a/aiven_db_migrate/migrate/pgmigrate.py b/aiven_db_migrate/migrate/pgmigrate.py index 1997435..089d7e9 100644 --- a/aiven_db_migrate/migrate/pgmigrate.py +++ b/aiven_db_migrate/migrate/pgmigrate.py @@ -35,6 +35,15 @@ MAX_CLI_LEN = 2097152 # getconf ARG_MAX +class ReplicationObjectType(enum.Enum): + PUBLICATION = "pub" + SUBSCRIPTION = "sub" + REPLICATION_SLOT = "slot" + + def get_display_name(self) -> str: + return self.name.replace("_", " ").lower() + + @dataclass class PGExtension: name: str @@ -97,6 +106,7 @@ class PGRole: class PGCluster: """PGCluster is a collection of databases managed by a single PostgreSQL server instance""" + DB_OBJECT_PREFIX = "managed_db_migrate" conn_info: Dict[str, Any] _databases: Dict[str, PGDatabase] _params: Dict[str, str] @@ -438,6 +448,10 @@ def mangle_db_name(self, db_name: str) -> str: return db_name return hashlib.md5(db_name.encode()).hexdigest() + def
get_replication_object_name(self, dbname: str, replication_obj_type: ReplicationObjectType) -> str: + mangled_name = self.mangle_db_name(dbname) + return f"{self.DB_OBJECT_PREFIX}_{mangled_name}_{replication_obj_type.value}" + class PGSource(PGCluster): """Source PostgreSQL cluster""" @@ -454,8 +468,10 @@ def get_size(self, *, dbname, only_tables: Optional[List[str]] = None) -> float: return result[0]["size"] or 0 def create_publication(self, *, dbname: str, only_tables: Optional[List[str]] = None) -> str: - mangled_name = self.mangle_db_name(dbname) - pubname = f"managed_db_migrate_{mangled_name}_pub" + pubname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.PUBLICATION, + ) validate_pg_identifier_length(pubname) pub_options: Union[List[str], str] @@ -498,8 +514,11 @@ def create_publication(self, *, dbname: str, only_tables: Optional[List[str]] = return pubname def create_replication_slot(self, *, dbname: str) -> str: - mangled_name = self.mangle_db_name(dbname) - slotname = f"managed_db_migrate_{mangled_name}_slot" + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) + validate_pg_identifier_length(slotname) self.log.info("Creating replication slot %r in database %r", slotname, dbname) @@ -516,12 +535,17 @@ def create_replication_slot(self, *, dbname: str) -> str: return slotname - def get_publication(self, *, dbname: str, pubname: str) -> Dict[str, Any]: + def get_publication(self, *, dbname: str) -> Dict[str, Any]: + pubname = self.get_replication_object_name(dbname=dbname, replication_obj_type=ReplicationObjectType.PUBLICATION) # publications as per database so connect to given database result = self.c("SELECT * FROM pg_catalog.pg_publication WHERE pubname = %s", args=(pubname, ), dbname=dbname) return result[0] if result else {} - def get_replication_slot(self, *, dbname: str, slotname: str) -> Dict[str, Any]: + def 
get_replication_slot(self, *, dbname: str) -> Dict[str, Any]: + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) result = self.c( "SELECT * from pg_catalog.pg_replication_slots WHERE database = %s AND slot_name = %s", args=( @@ -532,7 +556,11 @@ def get_replication_slot(self, *, dbname: str, slotname: str) -> Dict[str, Any]: ) return result[0] if result else {} - def replication_in_sync(self, *, dbname: str, slotname: str, max_replication_lag: int) -> Tuple[bool, str]: + def replication_in_sync(self, *, dbname: str, max_replication_lag: int) -> Tuple[bool, str]: + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) exists = self.c( "SELECT 1 FROM pg_catalog.pg_replication_slots WHERE slot_name = %s", args=(slotname, ), dbname=dbname ) @@ -573,28 +601,63 @@ def large_objects_present(self, *, dbname: str) -> bool: self.log.warning("Unable to determine if large objects present in database %r", dbname) return False - def cleanup(self, *, dbname: str, pubname: str, slotname: str): - # publications as per database so connect to correct database - pub = self.get_publication(dbname=dbname, pubname=pubname) - if pub: - self.log.info("Dropping publication %r from database %r", pub, dbname) - self.c("DROP PUBLICATION {}".format(pub["pubname"]), dbname=dbname, return_rows=0) - slot = self.get_replication_slot(dbname=dbname, slotname=slotname) - if slot: - self.log.info("Dropping replication slot %r from database %r", slot, dbname) - self.c( - "SELECT 1 FROM pg_catalog.pg_drop_replication_slot(%s)", - args=(slot["slot_name"], ), - dbname=dbname, - return_rows=0 + def cleanup(self, *, dbname: str): + self._cleanup_replication_object(dbname=dbname, replication_object_type=ReplicationObjectType.PUBLICATION) + self._cleanup_replication_object(dbname=dbname, replication_object_type=ReplicationObjectType.REPLICATION_SLOT) + + def 
_cleanup_replication_object(self, dbname: str, replication_object_type: ReplicationObjectType): + rep_obj_type_display_name = replication_object_type.get_display_name() + + rep_obj_name = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=replication_object_type, + ) + try: + if ReplicationObjectType.PUBLICATION is replication_object_type: + rep_obj = self.get_publication(dbname=dbname) + delete_query = f"DROP PUBLICATION {rep_obj_name};" + args = () + else: + rep_obj = self.get_replication_slot(dbname=dbname) + delete_query = f"SELECT 1 FROM pg_catalog.pg_drop_replication_slot(%s)" + args = (rep_obj_name, ) + + if not rep_obj: + return + + self.log.info( + "Dropping %r %r from database %r", + rep_obj_type_display_name, + rep_obj_name, + dbname, + ) + self.c(delete_query, args=args, dbname=dbname, return_rows=0) + except Exception as exc: + self.log.error( + "Failed to drop %r %r for database %r: %s", + rep_obj_type_display_name, + rep_obj_name, + dbname, + exc, ) class PGTarget(PGCluster): """Target PostgreSQL cluster""" - def create_subscription(self, *, conn_str: str, pubname: str, slotname: str, dbname: str) -> str: - mangled_name = self.mangle_db_name(dbname) - subname = f"managed_db_migrate_{mangled_name}_sub" + def create_subscription(self, *, conn_str: str, dbname: str) -> str: + pubname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.PUBLICATION, + ) + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) + + subname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.SUBSCRIPTION, + ) validate_pg_identifier_length(subname) has_aiven_extras = self.has_aiven_extras(dbname=dbname) @@ -630,7 +693,11 @@ def create_subscription(self, *, conn_str: str, pubname: str, slotname: str, dbn return subname - def get_subscription(self, *, dbname: str, subname: str) -> Dict[str, 
Any]: + def get_subscription(self, *, dbname: str) -> Dict[str, Any]: + subname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.SUBSCRIPTION, + ) if self.has_aiven_extras(dbname=dbname): result = self.c( "SELECT * FROM aiven_extras.pg_list_all_subscriptions() WHERE subname = %s", args=(subname, ), dbname=dbname @@ -640,7 +707,11 @@ def get_subscription(self, *, dbname: str, subname: str) -> Dict[str, Any]: result = self.c("SELECT * FROM pg_catalog.pg_subscription WHERE subname = %s", args=(subname, ), dbname=dbname) return result[0] if result else {} - def replication_in_sync(self, *, dbname: str, subname: str, write_lsn: str, max_replication_lag: int) -> bool: + def replication_in_sync(self, *, dbname: str, write_lsn: str, max_replication_lag: int) -> bool: + subname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.SUBSCRIPTION, + ) status = self.c( """ SELECT stat.*, @@ -664,23 +735,32 @@ def replication_in_sync(self, *, dbname: str, subname: str, write_lsn: str, max_ self.log.warning("Replication status not available for %r in database %r", subname, dbname) return False - def cleanup(self, *, dbname: str, subname: str): - sub = self.get_subscription(dbname=dbname, subname=subname) - if sub: - self.log.info("Dropping subscription %r from database %r", sub["subname"], dbname) + def cleanup(self, *, dbname: str): + subname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.SUBSCRIPTION, + ) + try: + if not self.get_subscription(dbname=dbname): + return + + self.log.info("Dropping subscription %r from database %r", subname, dbname) if self.has_aiven_extras(dbname=dbname): # NOTE: this drops also replication slot from source - self.c( - "SELECT * FROM aiven_extras.pg_drop_subscription(%s)", - args=(sub["subname"], ), - dbname=dbname, - return_rows=0 - ) + self.c("SELECT * FROM aiven_extras.pg_drop_subscription(%s)", 
args=(subname, ), dbname=dbname, return_rows=0) else: # requires superuser or superuser-like privileges, such as "rds_replication" role in AWS RDS - self.c("ALTER SUBSCRIPTION {} DISABLE".format(sub["subname"]), dbname=dbname, return_rows=0) - self.c("ALTER SUBSCRIPTION {} SET (slot_name = NONE)".format(sub["subname"]), dbname=dbname, return_rows=0) - self.c("DROP SUBSCRIPTION {}".format(sub["subname"]), dbname=dbname, return_rows=0) + self.c("ALTER SUBSCRIPTION {} DISABLE".format(subname), dbname=dbname, return_rows=0) + self.c("ALTER SUBSCRIPTION {} SET (slot_name = NONE)".format(subname), dbname=dbname, return_rows=0) + self.c("DROP SUBSCRIPTION {}".format(subname), dbname=dbname, return_rows=0) + + except Exception as exc: + self.log.error( + "Failed to drop subscription %r for database %r: %s", + subname, + dbname, + exc, + ) @enum.unique @@ -1256,42 +1336,48 @@ def _dump_data(self, *, db: PGDatabase) -> PGMigrateStatus: raise PGDataDumpFailedError(f"Failed to dump data: {subtask!r}") return PGMigrateStatus.done - def _wait_for_replication(self, *, dbname: str, slotname: str, subname: str, check_interval: float = 2.0): + def _wait_for_replication(self, *, dbname: str, check_interval: float = 2.0): + slotname = self.source.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) + subname = self.target.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.SUBSCRIPTION, + ) + while True: - in_sync, write_lsn = self.source.replication_in_sync( - dbname=dbname, slotname=slotname, max_replication_lag=self.max_replication_lag - ) + in_sync, write_lsn = self.source.replication_in_sync(dbname=dbname, max_replication_lag=self.max_replication_lag) if in_sync and self.target.replication_in_sync( - dbname=dbname, subname=subname, write_lsn=write_lsn, max_replication_lag=self.max_replication_lag + dbname=dbname, write_lsn=write_lsn, max_replication_lag=self.max_replication_lag ): 
break time.sleep(check_interval) def _db_replication(self, *, db: PGDatabase) -> PGMigrateStatus: dbname = db.dbname - pubname = slotname = subname = None try: tables = self.filter_tables(db) or [] - pubname = self.source.create_publication(dbname=dbname, only_tables=tables) - slotname = self.source.create_replication_slot(dbname=dbname) - subname = self.target.create_subscription( - conn_str=self.source.conn_str(dbname=dbname), pubname=pubname, slotname=slotname, dbname=dbname - ) + self.source.create_publication(dbname=dbname, only_tables=tables) + self.source.create_replication_slot(dbname=dbname) + self.target.create_subscription(conn_str=self.source.conn_str(dbname=dbname), dbname=dbname) + except psycopg2.ProgrammingError as e: self.log.error("Encountered error: %r, cleaning up", e) - if subname: - self.target.cleanup(dbname=dbname, subname=subname) - if pubname and slotname: - self.source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname) + + # clean-up replication objects, avoid leaving traces specially in source + self.target.cleanup(dbname=dbname) + self.source.cleanup(dbname=dbname) raise self.log.info("Logical replication setup successful for database %r", dbname) if self.max_replication_lag > -1: - self._wait_for_replication(dbname=dbname, slotname=slotname, subname=subname) + self._wait_for_replication(dbname=dbname) if self.stop_replication: - self.target.cleanup(dbname=dbname, subname=subname) - self.source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname) + self.target.cleanup(dbname=dbname) + self.source.cleanup(dbname=dbname) return PGMigrateStatus.done + # leaving replication running return PGMigrateStatus.running diff --git a/aiven_db_migrate/migrate/version.py b/aiven_db_migrate/migrate/version.py index c8811fb..147c1e4 100644 --- a/aiven_db_migrate/migrate/version.py +++ b/aiven_db_migrate/migrate/version.py @@ -1 +1 @@ -__version__ = "0.1.5-2-gfd83d7c" +__version__ = "0.1.5-2-ga13c553" diff --git a/test/conftest.py 
b/test/conftest.py index 6c3815a..cf7ef88 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,7 +4,7 @@ from _pytest.fixtures import FixtureRequest from _pytest.tmpdir import TempPathFactory -from aiven_db_migrate.migrate.pgmigrate import PGTarget +from aiven_db_migrate.migrate.pgmigrate import PGTarget, ReplicationObjectType from contextlib import contextmanager from copy import copy from functools import partial, wraps @@ -154,8 +154,12 @@ def _drop_replication_slot(pg_runner_: PGRunner, slot_name: str) -> None: break # Found it, no need to try other databases. @wraps(function) - def wrapper(self: PGTarget, *args, slotname: str, **kwargs) -> R: - subname = function(self, *args, slotname=slotname, **kwargs) + def wrapper(self: PGTarget, *args, dbname: str, **kwargs) -> R: + subname = function(self, *args, dbname=dbname, **kwargs) + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) pg_runner.cleanups.append(partial(_drop_replication_slot, pg_runner_=pg_runner, slot_name=slotname)) diff --git a/test/test_pg_migrate.py b/test/test_pg_migrate.py index e0bc865..18e84bd 100644 --- a/test/test_pg_migrate.py +++ b/test/test_pg_migrate.py @@ -338,7 +338,6 @@ def test_migrate_source_aiven_extras(self, createdb: bool): result: PGMigrateResult = pg_mig.migrate() - assert len(result.pg_databases) == 2 self.assert_result( result=result.pg_databases[dbname], dbname=dbname, diff --git a/test/test_pg_replication.py b/test/test_pg_replication.py index 12397cd..0429e31 100644 --- a/test/test_pg_replication.py +++ b/test/test_pg_replication.py @@ -45,19 +45,18 @@ def test_replication(pg_source_and_target: Tuple[PGRunner, PGRunner], aiven_extr pubname = pg_source.create_publication(dbname=dbname) slotname = pg_source.create_replication_slot(dbname=dbname) # verify that pub and replication slot exixts - pub = pg_source.get_publication(dbname=dbname, pubname=pubname) + pub = 
pg_source.get_publication(dbname=dbname) assert pub assert pub["pubname"] == pubname - slot = pg_source.get_replication_slot(dbname=dbname, slotname=slotname) + slot = pg_source.get_replication_slot(dbname=dbname) assert slot assert slot["slot_name"] == slotname assert slot["slot_type"] == "logical" - subname = pg_target.create_subscription( - conn_str=pg_source.conn_str(dbname=dbname), pubname=pubname, slotname=slotname, dbname=dbname - ) + conn_str = pg_source.conn_str(dbname=dbname) + subname = pg_target.create_subscription(conn_str=conn_str, dbname=dbname) # verify that sub exists - sub = pg_target.get_subscription(dbname=dbname, subname=subname) + sub = pg_target.get_subscription(dbname=dbname) assert sub assert sub["subname"] == subname assert sub["subenabled"] @@ -69,10 +68,8 @@ def test_replication(pg_source_and_target: Tuple[PGRunner, PGRunner], aiven_extr # wait until replication is in sync timer = Timer(timeout=10, what="replication in sync") while timer.loop(): - in_sync, write_lsn = pg_source.replication_in_sync(dbname=dbname, slotname=slotname, max_replication_lag=0) - if in_sync and pg_target.replication_in_sync( - dbname=dbname, subname=subname, write_lsn=write_lsn, max_replication_lag=0 - ): + in_sync, write_lsn = pg_source.replication_in_sync(dbname=dbname, max_replication_lag=0) + if in_sync and pg_target.replication_in_sync(dbname=dbname, write_lsn=write_lsn, max_replication_lag=0): break # verify that all data has been replicated @@ -82,8 +79,8 @@ def test_replication(pg_source_and_target: Tuple[PGRunner, PGRunner], aiven_extr if int(count["count"]) == 5: break - pg_target.cleanup(dbname=dbname, subname=subname) - pg_source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname) + pg_target.cleanup(dbname=dbname) + pg_source.cleanup(dbname=dbname) # verify that pub, replication slot and sub are dropped assert not source.list_pubs(dbname=dbname) @@ -113,7 +110,7 @@ def test_replication_no_aiven_extras_no_superuser(pg_source_and_target: 
Tuple[PG # creating subscription should fail with insufficient privilege with pytest.raises(psycopg2.ProgrammingError) as err: - pg_target.create_subscription(conn_str=pg_source.conn_str(), pubname="dummy", slotname="dummy", dbname=dbname) + pg_target.create_subscription(conn_str=pg_source.conn_str(), dbname=dbname) assert err.value.pgcode == psycopg2.errorcodes.INSUFFICIENT_PRIVILEGE privilege_error_message = "must be superuser to create subscriptions" diff --git a/test/test_table_filtering.py b/test/test_table_filtering.py index 9ef61e8..72fec4a 100644 --- a/test/test_table_filtering.py +++ b/test/test_table_filtering.py @@ -213,8 +213,6 @@ def test_replicate_filter_with(pg_source_and_target: Tuple[PGRunner, PGRunner], except psycopg2.Error: pass try: - pg_mig.source.cleanup( - dbname=db, pubname=f"managed_db_migrate_{db}_pub", slotname=f"managed_db_migrate_{db}_slot" - ) + pg_mig.source.cleanup(dbname=db) except: # pylint: disable=bare-except pass