diff --git a/src/sbosc/controller/validator.py b/src/sbosc/controller/validator.py
index 8a97943..f790d82 100644
--- a/src/sbosc/controller/validator.py
+++ b/src/sbosc/controller/validator.py
@@ -6,7 +6,7 @@
 import MySQLdb
 from MySQLdb.cursors import Cursor
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generator, List

 from modules.db import Database
 from sbosc.exceptions import StopFlagSet
@@ -39,7 +39,7 @@ def __init__(self, controller: 'Controller'):
     def set_stop_flag(self):
         self.stop_flag = True

-    def handle_operational_error(self, e, range_queue, start_range, end_range):
+    def __handle_operational_error(self, e, range_queue, start_range, end_range):
         if e.args[0] == 2013:
             self.logger.warning("Query timeout. Retry with smaller batch size")
             range_queue.put((start_range, start_range + (end_range - start_range) // 2))
@@ -50,7 +50,7 @@ def handle_operational_error(self, e, range_queue, start_range, end_range):
             range_queue.put((start_range, end_range))
             time.sleep(3)

-    def validate_bulk_import_batch(self, range_queue: Queue, failed_pks):
+    def __validate_bulk_import_batch(self, range_queue: Queue, failed_pks):
         with self.source_conn_pool.get_connection() as source_conn, self.dest_conn_pool.get_connection() as dest_conn:
             while not range_queue.empty():
                 if len(failed_pks) > 0:
@@ -68,7 +68,7 @@ def validate_bulk_import_batch(self, range_queue: Queue, failed_pks):
                         failed_pks.extend(not_imported_pks)
                         return False
                 except MySQLdb.OperationalError as e:
-                    self.handle_operational_error(e, range_queue, batch_start_pk, batch_end_pk)
+                    self.__handle_operational_error(e, range_queue, batch_start_pk, batch_end_pk)
                     source_conn.ping(True)
                     dest_conn.ping(True)
                     continue
@@ -83,7 +83,7 @@ def bulk_import_validation(self):
         metadata = self.redis_data.metadata
         range_queue = Queue()
         start_pk = 0
-        while start_pk < metadata.max_id:
+        while start_pk <= metadata.max_id:
             range_queue.put((start_pk, min(start_pk + self.bulk_import_batch_size, metadata.max_id)))
             start_pk += self.bulk_import_batch_size + 1
         failed_pks = []
@@ -91,7 +91,7 @@
         with concurrent.futures.ThreadPoolExecutor(max_workers=self.thread_count) as executor:
             threads = []
             for _ in range(self.thread_count):
-                threads.append(executor.submit(self.validate_bulk_import_batch, range_queue, failed_pks))
+                threads.append(executor.submit(self.__validate_bulk_import_batch, range_queue, failed_pks))
             is_valid = all([thread.result() for thread in threads])
             if not is_valid:
                 self.logger.critical(f"Failed to validate bulk import. Failed pks: {failed_pks}")
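
The `<=` change in bulk_import_validation above closes an off-by-one gap: each queued range is inclusive on both ends and the next range starts one past the previous end, so under the old `<` condition a max_id that lands exactly on a chunk boundary never entered the queue. A minimal standalone sketch of the boundary behavior (the numbers are illustrative, not from the codebase):

    def chunk_ranges(max_id: int, batch_size: int, inclusive: bool) -> list:
        # Mirrors the queue-filling loop: ranges are inclusive on both ends,
        # and the next range starts one past the previous end.
        ranges = []
        start_pk = 0
        while (start_pk <= max_id) if inclusive else (start_pk < max_id):
            ranges.append((start_pk, min(start_pk + batch_size, max_id)))
            start_pk += batch_size + 1
        return ranges

    print(chunk_ranges(2000, 999, inclusive=False))  # [(0, 999), (1000, 1999)] -- pk 2000 never validated
    print(chunk_ranges(2000, 999, inclusive=True))   # [(0, 999), (1000, 1999), (2000, 2000)]
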
@@ -99,7 +99,7 @@
         self.logger.info("Bulk import validation succeeded")
         return is_valid

-    def get_timestamp_range(self):
+    def __get_timestamp_range(self):
         start_timestamp = None
         end_timestamp = None
         with self.db.cursor() as cursor:
@@ -136,36 +136,30 @@ def get_timestamp_range(self):
             end_timestamp = cursor.fetchone()[0]
         return start_timestamp, end_timestamp

-    def execute_apply_dml_events_validation_query(
-            self, source_cursor, dest_cursor, table, start_timestamp, end_timestamp, unmatched_pks):
+    def __execute_apply_dml_events_validation_query(
+            self, source_cursor, dest_cursor, table, event_pks: list, unmatched_pks: list):
         metadata = self.redis_data.metadata
         if table == 'inserted_pk':
-            not_inserted_pks = self.migration_operation.get_not_inserted_pks(
-                source_cursor, dest_cursor, start_timestamp, end_timestamp)
+            not_inserted_pks = self.migration_operation.get_not_inserted_pks(source_cursor, dest_cursor, event_pks)
             if not_inserted_pks:
                 self.logger.warning(f"Found {len(not_inserted_pks)} unmatched inserted pks: {not_inserted_pks}")
                 unmatched_pks.extend([(pk, UnmatchType.NOT_UPDATED) for pk in not_inserted_pks])
         elif table == 'updated_pk':
-            not_updated_pks = self.migration_operation.get_not_updated_pks(
-                source_cursor, dest_cursor, start_timestamp, end_timestamp)
+            not_updated_pks = self.migration_operation.get_not_updated_pks(source_cursor, dest_cursor, event_pks)
             if not_updated_pks:
                 self.logger.warning(f"Found {len(not_updated_pks)} unmatched updated pks: {not_updated_pks}")
                 unmatched_pks.extend([(pk, UnmatchType.NOT_UPDATED) for pk in not_updated_pks])
         elif table == 'deleted_pk':
-            source_cursor.execute(f'''
-                SELECT source_pk FROM {config.SBOSC_DB}.deleted_pk_{self.migration_id}
-                WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp}
-            ''')
-            if source_cursor.rowcount > 0:
-                target_pks = ','.join([str(row[0]) for row in source_cursor.fetchall()])
+            if event_pks:
+                event_pks_str = ','.join([str(pk) for pk in event_pks])
                 dest_cursor.execute(f'''
-                    SELECT id FROM {metadata.destination_db}.{metadata.destination_table} WHERE id IN ({target_pks})
+                    SELECT id FROM {metadata.destination_db}.{metadata.destination_table} WHERE id IN ({event_pks_str})
                 ''')
                 not_deleted_pks = set([row[0] for row in dest_cursor.fetchall()])
                 if dest_cursor.rowcount > 0:
                     # Check if deleted pks are reinserted
                     source_cursor.execute(f'''
-                        SELECT id FROM {metadata.source_db}.{metadata.source_table} WHERE id IN ({target_pks})
+                        SELECT id FROM {metadata.source_db}.{metadata.source_table} WHERE id IN ({event_pks_str})
                     ''')
                     reinserted_pks = set([row[0] for row in source_cursor.fetchall()])
                     if reinserted_pks:
@@ -174,7 +168,17 @@ def execute_apply_dml_events_validation_query(
                     self.logger.warning(f"Found {len(not_deleted_pks)} unmatched deleted pks: {not_deleted_pks}")
                     unmatched_pks.extend([(pk, UnmatchType.NOT_REMOVED) for pk in not_deleted_pks])

-    def validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_pks):
+    def __get_event_pk_batch(self, cursor, table, start_timestamp, end_timestamp) -> Generator[List[int], None, None]:
+        cursor.execute(f'''
+            SELECT source_pk FROM {config.SBOSC_DB}.{table}_{self.migration_id}
+            WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp}
+        ''')
+        event_pks = [row[0] for row in cursor.fetchall()]
+        while event_pks:
+            yield event_pks[:self.apply_dml_events_batch_size]
+            event_pks = event_pks[self.apply_dml_events_batch_size:]
+
+    def __validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_pks):
         with self.source_conn_pool.get_connection() as source_conn, self.dest_conn_pool.get_connection() as dest_conn:
             while not range_queue.empty():
                 if self.stop_flag:
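
The new __get_event_pk_batch generator above fetches every pk in the timestamp window once, then hands out slices of at most apply_dml_events_batch_size, keeping each interpolated IN (...) clause bounded. The slicing idiom in isolation, as a sketch with illustrative values:

    from typing import Generator, List

    def batch_pks(event_pks: List[int], batch_size: int) -> Generator[List[int], None, None]:
        # Yield fixed-size slices until the list is exhausted;
        # the final slice may be shorter.
        while event_pks:
            yield event_pks[:batch_size]
            event_pks = event_pks[batch_size:]

    print(list(batch_pks([1, 2, 3, 4, 5], 2)))  # [[1, 2], [3, 4], [5]]
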
@@ -182,7 +186,7 @@ def validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_p

                 try:
                     batch_start_timestamp, batch_end_timestamp = range_queue.get_nowait()
-                    self.logger.info(f"Validating DML events from {batch_start_timestamp} to {batch_end_timestamp}")
+                    self.logger.info(f"Validating {table} from {batch_start_timestamp} to {batch_end_timestamp}")
                 except Empty:
                     self.logger.warning("Range queue is empty")
                     continue
@@ -198,7 +202,7 @@
                         WHERE event_timestamp BETWEEN {batch_start_timestamp} AND {batch_end_timestamp}
                     ''')
                     event_count = source_cursor.fetchone()[0]
-                    if event_count > self.apply_dml_events_batch_size:
+                    if event_count > self.apply_dml_events_batch_size and batch_end_timestamp > batch_start_timestamp:
                         range_queue.put((
                             batch_start_timestamp,
                             batch_start_timestamp + (batch_end_timestamp - batch_start_timestamp) // 2
@@ -211,17 +215,20 @@

                     else:
                         try:
-                            self.execute_apply_dml_events_validation_query(
-                                source_cursor, dest_cursor, table,
-                                batch_start_timestamp, batch_end_timestamp, unmatched_pks
+                            event_pk_batch = self.__get_event_pk_batch(
+                                source_cursor, table, batch_start_timestamp, batch_end_timestamp
                             )
+                            while event_pks := next(event_pk_batch, None):
+                                self.__execute_apply_dml_events_validation_query(
+                                    source_cursor, dest_cursor, table, event_pks, unmatched_pks
+                                )
                         except MySQLdb.OperationalError as e:
-                            self.handle_operational_error(e, range_queue, batch_start_timestamp, batch_end_timestamp)
+                            self.__handle_operational_error(e, range_queue, batch_start_timestamp, batch_end_timestamp)
                             source_conn.ping(True)
                             dest_conn.ping(True)
                             continue

-    def validate_unmatched_pks(self):
+    def __validate_unmatched_pks(self):
         self.logger.info("Validating unmatched pks")
         with self.db.cursor() as cursor:
             cursor: Cursor
@@ -279,7 +286,7 @@ def validate_apply_dml_events(self, start_timestamp, end_timestamp):
         if table_rows > 0:
             range_queue = Queue()
             batch_start_timestamp = start_timestamp
-            while batch_start_timestamp < end_timestamp:
+            while batch_start_timestamp <= end_timestamp:
                 batch_duration = \
                     (end_timestamp - start_timestamp) * self.apply_dml_events_batch_size // table_rows
                 batch_end_timestamp = min(batch_start_timestamp + batch_duration, end_timestamp)
@@ -290,7 +297,7 @@
                 threads = []
                 for _ in range(self.thread_count):
                     threads.append(executor.submit(
-                        self.validate_apply_dml_events_batch, table, range_queue, unmatched_pks))
+                        self.__validate_apply_dml_events_batch, table, range_queue, unmatched_pks))
                 for thread in threads:
                     thread.result()

@@ -298,7 +305,7 @@
                 INSERT IGNORE INTO {config.SBOSC_DB}.unmatched_rows (source_pk, migration_id, unmatch_type)
                 VALUES (%s, {self.migration_id}, %s)
             ''', unmatched_pks)
-            self.validate_unmatched_pks()
+            self.__validate_unmatched_pks()
             cursor.execute(
                 f"SELECT COUNT(1) FROM {config.SBOSC_DB}.unmatched_rows WHERE migration_id = {self.migration_id}")
             unmatched_rows = cursor.fetchone()[0]
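
The added `batch_end_timestamp > batch_start_timestamp` condition above is a termination guard: halving a window that has already collapsed to a single second yields the same window back, so the old code could re-queue it forever whenever one second held more events than the batch size; the generator-based path now processes such a window in slices instead. A sketch of the failure mode, with illustrative timestamps:

    def halve(start_ts: int, end_ts: int) -> tuple:
        # First half of the split performed when a window holds too many events.
        return start_ts, start_ts + (end_ts - start_ts) // 2

    print(halve(100, 200))  # (100, 150) -- the window shrinks, progress is made
    print(halve(201, 201))  # (201, 201) -- same window back, would loop forever
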
events validation") - start_timestamp, end_timestamp = self.get_timestamp_range() + start_timestamp, end_timestamp = self.__get_timestamp_range() if start_timestamp is None: self.logger.warning("No events found. Skipping apply DML events validation") return True @@ -334,6 +341,10 @@ def full_dml_event_validation(self): """ :return: True if validation ran, False if validation skipped """ + if self.full_dml_event_validation_interval == 0: + self.logger.info("Full DML event validation is disabled") + return False + self.logger.info("Start full DML event validation") with self.db.cursor(role='reader') as cursor: @@ -347,7 +358,9 @@ def full_dml_event_validation(self): last_validation_time = cursor.fetchone()[0] if datetime.now() - last_validation_time < timedelta(hours=self.full_dml_event_validation_interval): self.logger.info( - "Last validation was done less than 1 hour ago. Skipping full DML event validation") + f"Last validation was done less than {self.full_dml_event_validation_interval} hour ago. " + f"Skipping full DML event validation" + ) return False cursor.execute(f''' diff --git a/src/sbosc/operations/base.py b/src/sbosc/operations/base.py index 98d1100..3ad0ebc 100644 --- a/src/sbosc/operations/base.py +++ b/src/sbosc/operations/base.py @@ -54,33 +54,33 @@ def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk): not_imported_pks = [row[0] for row in source_cursor.fetchall()] return not_imported_pks - def get_not_inserted_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp): + def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks): not_inserted_pks = [] - event_pks = self._get_event_pks(source_cursor, 'insert', start_timestamp, end_timestamp) if event_pks: + event_pks_str = ','.join([str(pk) for pk in event_pks]) source_cursor.execute(f''' SELECT source.id FROM {self.source_db}.{self.source_table} AS source LEFT JOIN {self.destination_db}.{self.destination_table} AS dest ON source.id = dest.id - WHERE source.id IN ({event_pks}) + WHERE source.id IN ({event_pks_str}) AND dest.id IS NULL ''') not_inserted_pks = [row[0] for row in source_cursor.fetchall()] return not_inserted_pks - def get_not_updated_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp): + def get_not_updated_pks(self, source_cursor, dest_cursor, event_pks): not_updated_pks = [] - event_pks = self._get_event_pks(source_cursor, 'update', start_timestamp, end_timestamp) if event_pks: + event_pks_str = ','.join([str(pk) for pk in event_pks]) source_cursor.execute(f''' SELECT combined.id FROM ( SELECT {self.source_columns}, 'source' AS table_type FROM {self.source_db}.{self.source_table} - WHERE id IN ({event_pks}) + WHERE id IN ({event_pks_str}) UNION ALL SELECT {self.source_columns}, 'destination' AS table_type FROM {self.destination_db}.{self.destination_table} - WHERE id IN ({event_pks}) + WHERE id IN ({event_pks_str}) ) AS combined GROUP BY {self.source_columns} HAVING COUNT(1) = 1 AND SUM(table_type = 'source') = 1 @@ -190,30 +190,30 @@ def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk): dest_pks = [row[0] for row in dest_cursor.fetchall()] return list(set(source_pks) - set(dest_pks)) - def get_not_inserted_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp): + def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks): not_inserted_pks = [] - event_pks = self._get_event_pks(source_cursor, 'insert', start_timestamp, end_timestamp) if event_pks: - source_cursor.execute(f"SELECT id FROM 
diff --git a/src/sbosc/operations/base.py b/src/sbosc/operations/base.py
index 98d1100..3ad0ebc 100644
--- a/src/sbosc/operations/base.py
+++ b/src/sbosc/operations/base.py
@@ -54,33 +54,33 @@ def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk):
         not_imported_pks = [row[0] for row in source_cursor.fetchall()]
         return not_imported_pks

-    def get_not_inserted_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp):
+    def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks):
         not_inserted_pks = []
-        event_pks = self._get_event_pks(source_cursor, 'insert', start_timestamp, end_timestamp)
         if event_pks:
+            event_pks_str = ','.join([str(pk) for pk in event_pks])
             source_cursor.execute(f'''
                 SELECT source.id FROM {self.source_db}.{self.source_table} AS source
                 LEFT JOIN {self.destination_db}.{self.destination_table} AS dest ON source.id = dest.id
-                WHERE source.id IN ({event_pks})
+                WHERE source.id IN ({event_pks_str})
                 AND dest.id IS NULL
             ''')
             not_inserted_pks = [row[0] for row in source_cursor.fetchall()]
         return not_inserted_pks

-    def get_not_updated_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp):
+    def get_not_updated_pks(self, source_cursor, dest_cursor, event_pks):
         not_updated_pks = []
-        event_pks = self._get_event_pks(source_cursor, 'update', start_timestamp, end_timestamp)
         if event_pks:
+            event_pks_str = ','.join([str(pk) for pk in event_pks])
             source_cursor.execute(f'''
                 SELECT combined.id FROM (
                     SELECT {self.source_columns}, 'source' AS table_type
                     FROM {self.source_db}.{self.source_table}
-                    WHERE id IN ({event_pks})
+                    WHERE id IN ({event_pks_str})
                     UNION ALL
                     SELECT {self.source_columns}, 'destination' AS table_type
                     FROM {self.destination_db}.{self.destination_table}
-                    WHERE id IN ({event_pks})
+                    WHERE id IN ({event_pks_str})
                 ) AS combined
                 GROUP BY {self.source_columns}
                 HAVING COUNT(1) = 1 AND SUM(table_type = 'source') = 1
@@ -190,30 +190,30 @@ def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk):
         dest_pks = [row[0] for row in dest_cursor.fetchall()]
         return list(set(source_pks) - set(dest_pks))

-    def get_not_inserted_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp):
+    def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks):
         not_inserted_pks = []
-        event_pks = self._get_event_pks(source_cursor, 'insert', start_timestamp, end_timestamp)
         if event_pks:
-            source_cursor.execute(f"SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({event_pks})")
+            event_pks_str = ','.join([str(pk) for pk in event_pks])
+            source_cursor.execute(f"SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({event_pks_str})")
             source_pks = [row[0] for row in source_cursor.fetchall()]
             dest_cursor.execute(
-                f"SELECT id FROM {self.destination_db}.{self.destination_table} WHERE id IN ({event_pks})")
+                f"SELECT id FROM {self.destination_db}.{self.destination_table} WHERE id IN ({event_pks_str})")
             dest_pks = [row[0] for row in dest_cursor.fetchall()]
             not_inserted_pks = list(set(source_pks) - set(dest_pks))
         return not_inserted_pks

-    def get_not_updated_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp):
+    def get_not_updated_pks(self, source_cursor, dest_cursor, event_pks):
         not_updated_pks = []
-        event_pks = self._get_event_pks(source_cursor, 'update', start_timestamp, end_timestamp)
         if event_pks:
+            event_pks_str = ','.join([str(pk) for pk in event_pks])
             source_cursor.execute(f'''
                 SELECT {self.source_columns} FROM {self.source_db}.{self.source_table}
-                WHERE id IN ({event_pks})
+                WHERE id IN ({event_pks_str})
             ''')
             source_df = pd.DataFrame(source_cursor.fetchall(), columns=[c[0] for c in source_cursor.description])
             dest_cursor.execute(f'''
                 SELECT {self.source_columns} FROM {self.destination_db}.{self.destination_table}
-                WHERE id IN ({event_pks})
+                WHERE id IN ({event_pks_str})
             ''')
             dest_df = pd.DataFrame(dest_cursor.fetchall(), columns=[c[0] for c in dest_cursor.description])
diff --git a/src/sbosc/operations/operation.py b/src/sbosc/operations/operation.py
index 7c0d21b..0f47f5f 100644
--- a/src/sbosc/operations/operation.py
+++ b/src/sbosc/operations/operation.py
@@ -1,10 +1,9 @@
 from abc import abstractmethod
 from contextlib import contextmanager
-from typing import Literal
+from typing import List

 from MySQLdb.cursors import Cursor

-from config import config
 from modules.db import Database
 from modules.redis import RedisData
@@ -48,20 +47,8 @@ def get_not_imported_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start
         """
         pass

-    def _get_event_pks(
-            self, cursor: Cursor, event_type: Literal['insert', 'update'], start_timestamp, end_timestamp):
-        table_names = {
-            'insert': f'inserted_pk_{self.migration_id}',
-            'update': f'updated_pk_{self.migration_id}'
-        }
-        cursor.execute(f'''
-            SELECT source_pk FROM {config.SBOSC_DB}.{table_names[event_type]}
-            WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp}
-        ''')
-        return ','.join([str(row[0]) for row in cursor.fetchall()])
-
     @abstractmethod
-    def get_not_inserted_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start_timestamp, end_timestamp):
+    def get_not_inserted_pks(self, source_cursor: Cursor, dest_cursor: Cursor, event_pks: List[int]):
         """
         Returns a list of primary keys that have not been inserted into the destination table.
         Used in APPLY_DML_EVENTS_VALIDATION stage to validate that all inserts have been applied.
@@ -69,7 +56,7 @@ def get_not_inserted_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start
         """
         pass

     @abstractmethod
-    def get_not_updated_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start_timestamp, end_timestamp):
+    def get_not_updated_pks(self, source_cursor: Cursor, dest_cursor: Cursor, event_pks: List[int]):
         """
         Returns a list of primary keys that have not been updated in the destination table.
         Used in APPLY_DML_EVENTS_VALIDATION stage to validate that all updates have been applied.
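
With _get_event_pks removed, the contract across all three files is now uniform: the caller fetches and batches the pks, and each operation method receives event_pks: List[int] and renders the IN clause itself. A minimal sketch of the new call shape (the db and table names are placeholders, not the project's):

    from typing import List

    def render_in_clause(event_pks: List[int]) -> str:
        # The same join the diff introduces as event_pks_str.
        return ','.join(str(pk) for pk in event_pks)

    event_pks = [3, 7, 42]
    print(f"SELECT id FROM source_db.source_table WHERE id IN ({render_in_clause(event_pks)})")
    # SELECT id FROM source_db.source_table WHERE id IN (3,7,42)
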
diff --git a/tests/test_controller.py b/tests/test_controller.py
index 905ee71..11d56a5 100644
--- a/tests/test_controller.py
+++ b/tests/test_controller.py
@@ -189,10 +189,10 @@ def test_apply_dml_events_validation(controller: Controller, setup_table, redis_
         cursor.execute(f"TRUNCATE TABLE {config.SBOSC_DB}.event_handler_status")
         cursor.execute(f"TRUNCATE TABLE {config.SBOSC_DB}.apply_dml_events_status")

-    # Event handler status doesn't have any row
+    # event_handler_status table doesn't have any row
     assert not controller.validator.apply_dml_events_validation()

-    # Insert row to event handler status and validate
+    # Insert row to event_handler_status table and validate
     cursor.execute(f'''
         INSERT INTO {config.SBOSC_DB}.event_handler_status (migration_id, log_file, log_pos, last_event_timestamp, created_at)
         VALUES (1, 'mysql-bin.000001', 4, {timestamp_range[1]}, NOW())
@@ -241,7 +241,7 @@
     assert cursor.fetchone()[0] == 0

     # Add new insert, update event
-    new_timestamp_range = (101, 200)
+    new_timestamp_range = (100, 200)
     new_insert_events = [
         (random.randint(TABLE_SIZE, TABLE_SIZE * 2), random.randint(*new_timestamp_range)) for _ in range(500)]
     new_update_events = [(random.randint(1, TABLE_SIZE), random.randint(*new_timestamp_range)) for _ in range(500)]
@@ -278,6 +278,32 @@ def test_apply_dml_events_validation(controller: Controller, setup_table, redis_
     cursor.execute(f"SELECT COUNT(1) FROM {config.SBOSC_DB}.unmatched_rows")
     assert cursor.fetchone()[0] == 0

+    # More records inserted than apply_dml_events_validation batch size in 1 second
+    large_insert_events = {
+        (random.randint(TABLE_SIZE * 2, TABLE_SIZE * 3), 201) for _ in range(2000)}
+    cursor.executemany(f'''
+        INSERT IGNORE INTO {config.SBOSC_DB}.inserted_pk_1 (source_pk, event_timestamp) VALUES (%s, %s)
+    ''', large_insert_events)
+    cursor.executemany(f'''
+        INSERT IGNORE INTO {config.SOURCE_DB}.{config.SOURCE_TABLE} (id, A, B, C) VALUES (%s, %s, %s, %s)
+    ''', [(i[0], 'a', 'b', 'c') for i in large_insert_events])
+    cursor.execute(f'''
+        INSERT INTO {config.SBOSC_DB}.event_handler_status (migration_id, log_file, log_pos, last_event_timestamp, created_at)
+        VALUES (1, 'mysql-bin.000001', 4, 201, NOW())
+    ''')
+    controller.validator.apply_dml_events_validation()
+    cursor.execute(f"SELECT COUNT(1) FROM {config.SBOSC_DB}.unmatched_rows")
+    assert cursor.fetchone()[0] == len(large_insert_events)
+
+    # Apply changes to destination table
+    cursor.executemany(f'''
+        INSERT IGNORE INTO {config.DESTINATION_DB}.{config.DESTINATION_TABLE} (id, A, B, C) VALUES (%s, %s, %s, %s)
+    ''', [(i[0], 'a', 'b', 'c') for i in large_insert_events])
+
+    # requires 2 iterations to check all unmatched rows
+    controller.validator.apply_dml_events_validation()
+    assert controller.validator.apply_dml_events_validation()
+
     # Test full validation
     assert controller.validator.full_dml_event_validation()
     cursor.execute(f"SELECT is_valid FROM {config.SBOSC_DB}.full_dml_event_validation_status")
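
The closing assertions rely on apply_dml_events_validation() needing more than one pass before unmatched_rows drains, as the in-test comment notes. If that pattern recurs, a retry wrapper along these lines could express it; validate_until_clean is a hypothetical helper, not part of the test suite:

    def validate_until_clean(validate, max_iterations: int = 5) -> bool:
        # Re-run a boolean validation callable until it reports success
        # or the iteration budget is spent.
        for _ in range(max_iterations):
            if validate():
                return True
        return False

    # e.g. assert validate_until_clean(controller.validator.apply_dml_events_validation)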