Skip to content

Commit

Permalink
Merge pull request #10 from sendbird/bugfix/improve-apply-dml-events-validation-batch-size-handling
Browse files Browse the repository at this point in the history

Improve apply_dml_events_validation batch size handling
  • Loading branch information
jjh-kim authored Jul 25, 2024
2 parents 9cef334 + 8f7ed15 commit 58c8fa4
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 68 deletions.
81 changes: 47 additions & 34 deletions src/sbosc/controller/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import MySQLdb
from MySQLdb.cursors import Cursor

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generator, List

from modules.db import Database
from sbosc.exceptions import StopFlagSet
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(self, controller: 'Controller'):
def set_stop_flag(self):
self.stop_flag = True

def handle_operational_error(self, e, range_queue, start_range, end_range):
def __handle_operational_error(self, e, range_queue, start_range, end_range):
if e.args[0] == 2013:
self.logger.warning("Query timeout. Retry with smaller batch size")
range_queue.put((start_range, start_range + (end_range - start_range) // 2))
Expand All @@ -50,7 +50,7 @@ def handle_operational_error(self, e, range_queue, start_range, end_range):
range_queue.put((start_range, end_range))
time.sleep(3)

def validate_bulk_import_batch(self, range_queue: Queue, failed_pks):
def __validate_bulk_import_batch(self, range_queue: Queue, failed_pks):
with self.source_conn_pool.get_connection() as source_conn, self.dest_conn_pool.get_connection() as dest_conn:
while not range_queue.empty():
if len(failed_pks) > 0:
Expand All @@ -68,7 +68,7 @@ def validate_bulk_import_batch(self, range_queue: Queue, failed_pks):
failed_pks.extend(not_imported_pks)
return False
except MySQLdb.OperationalError as e:
self.handle_operational_error(e, range_queue, batch_start_pk, batch_end_pk)
self.__handle_operational_error(e, range_queue, batch_start_pk, batch_end_pk)
source_conn.ping(True)
dest_conn.ping(True)
continue
Expand All @@ -83,23 +83,23 @@ def bulk_import_validation(self):
metadata = self.redis_data.metadata
range_queue = Queue()
start_pk = 0
while start_pk < metadata.max_id:
while start_pk <= metadata.max_id:
range_queue.put((start_pk, min(start_pk + self.bulk_import_batch_size, metadata.max_id)))
start_pk += self.bulk_import_batch_size + 1
failed_pks = []

with concurrent.futures.ThreadPoolExecutor(max_workers=self.thread_count) as executor:
threads = []
for _ in range(self.thread_count):
threads.append(executor.submit(self.validate_bulk_import_batch, range_queue, failed_pks))
threads.append(executor.submit(self.__validate_bulk_import_batch, range_queue, failed_pks))
is_valid = all([thread.result() for thread in threads])
if not is_valid:
self.logger.critical(f"Failed to validate bulk import. Failed pks: {failed_pks}")
else:
self.logger.info("Bulk import validation succeeded")
return is_valid

def get_timestamp_range(self):
def __get_timestamp_range(self):
start_timestamp = None
end_timestamp = None
with self.db.cursor() as cursor:
Expand Down Expand Up @@ -136,36 +136,30 @@ def get_timestamp_range(self):
end_timestamp = cursor.fetchone()[0]
return start_timestamp, end_timestamp

def execute_apply_dml_events_validation_query(
self, source_cursor, dest_cursor, table, start_timestamp, end_timestamp, unmatched_pks):
def __execute_apply_dml_events_validation_query(
self, source_cursor, dest_cursor, table, event_pks: list, unmatched_pks: list):
metadata = self.redis_data.metadata
if table == 'inserted_pk':
not_inserted_pks = self.migration_operation.get_not_inserted_pks(
source_cursor, dest_cursor, start_timestamp, end_timestamp)
not_inserted_pks = self.migration_operation.get_not_inserted_pks(source_cursor, dest_cursor, event_pks)
if not_inserted_pks:
self.logger.warning(f"Found {len(not_inserted_pks)} unmatched inserted pks: {not_inserted_pks}")
unmatched_pks.extend([(pk, UnmatchType.NOT_UPDATED) for pk in not_inserted_pks])
elif table == 'updated_pk':
not_updated_pks = self.migration_operation.get_not_updated_pks(
source_cursor, dest_cursor, start_timestamp, end_timestamp)
not_updated_pks = self.migration_operation.get_not_updated_pks(source_cursor, dest_cursor, event_pks)
if not_updated_pks:
self.logger.warning(f"Found {len(not_updated_pks)} unmatched updated pks: {not_updated_pks}")
unmatched_pks.extend([(pk, UnmatchType.NOT_UPDATED) for pk in not_updated_pks])
elif table == 'deleted_pk':
source_cursor.execute(f'''
SELECT source_pk FROM {config.SBOSC_DB}.deleted_pk_{self.migration_id}
WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp}
''')
if source_cursor.rowcount > 0:
target_pks = ','.join([str(row[0]) for row in source_cursor.fetchall()])
if event_pks:
event_pks_str = ','.join([str(pk) for pk in event_pks])
dest_cursor.execute(f'''
SELECT id FROM {metadata.destination_db}.{metadata.destination_table} WHERE id IN ({target_pks})
SELECT id FROM {metadata.destination_db}.{metadata.destination_table} WHERE id IN ({event_pks_str})
''')
not_deleted_pks = set([row[0] for row in dest_cursor.fetchall()])
if dest_cursor.rowcount > 0:
# Check if deleted pks are reinserted
source_cursor.execute(f'''
SELECT id FROM {metadata.source_db}.{metadata.source_table} WHERE id IN ({target_pks})
SELECT id FROM {metadata.source_db}.{metadata.source_table} WHERE id IN ({event_pks_str})
''')
reinserted_pks = set([row[0] for row in source_cursor.fetchall()])
if reinserted_pks:
Expand All @@ -174,15 +168,25 @@ def execute_apply_dml_events_validation_query(
self.logger.warning(f"Found {len(not_deleted_pks)} unmatched deleted pks: {not_deleted_pks}")
unmatched_pks.extend([(pk, UnmatchType.NOT_REMOVED) for pk in not_deleted_pks])

def validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_pks):
def __get_event_pk_batch(self, cursor, table, start_timestamp, end_timestamp) -> Generator[List[int], None, None]:
    """
    Yield the source pks recorded in an event table for a timestamp range, in fixed-size batches.

    :param cursor: cursor used to read the event table
    :param table: event table name prefix; the migration id is appended to it
    :param start_timestamp: lower bound of event_timestamp (inclusive, SQL BETWEEN)
    :param end_timestamp: upper bound of event_timestamp (inclusive, SQL BETWEEN)
    :return: generator of non-empty lists holding at most apply_dml_events_batch_size pks each
    :raises ValueError: if apply_dml_events_batch_size is not a positive number
    """
    batch_size = self.apply_dml_events_batch_size
    if batch_size <= 0:
        # A non-positive size cannot produce valid batches; fail loudly instead of
        # silently skipping events that should have been validated.
        raise ValueError(f"apply_dml_events_batch_size must be positive, got {batch_size}")

    cursor.execute(f'''
        SELECT source_pk FROM {config.SBOSC_DB}.{table}_{self.migration_id}
        WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp}
    ''')
    # Read the whole result set before yielding: the caller runs further queries on this
    # same cursor between batches, which would replace any rows not fetched yet.
    event_pks = [row[0] for row in cursor.fetchall()]
    # Step an offset over the list instead of re-slicing the remainder on every iteration,
    # so each pk is copied once rather than once per remaining batch.
    for offset in range(0, len(event_pks), batch_size):
        yield event_pks[offset:offset + batch_size]

def __validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_pks):
with self.source_conn_pool.get_connection() as source_conn, self.dest_conn_pool.get_connection() as dest_conn:
while not range_queue.empty():
if self.stop_flag:
raise StopFlagSet()

try:
batch_start_timestamp, batch_end_timestamp = range_queue.get_nowait()
self.logger.info(f"Validating DML events from {batch_start_timestamp} to {batch_end_timestamp}")
self.logger.info(f"Validating {table} from {batch_start_timestamp} to {batch_end_timestamp}")
except Empty:
self.logger.warning("Range queue is empty")
continue
Expand All @@ -198,7 +202,7 @@ def validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_p
WHERE event_timestamp BETWEEN {batch_start_timestamp} AND {batch_end_timestamp}
''')
event_count = source_cursor.fetchone()[0]
if event_count > self.apply_dml_events_batch_size:
if event_count > self.apply_dml_events_batch_size and batch_end_timestamp > batch_start_timestamp:
range_queue.put((
batch_start_timestamp,
batch_start_timestamp + (batch_end_timestamp - batch_start_timestamp) // 2
Expand All @@ -211,17 +215,20 @@ def validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_p

else:
try:
self.execute_apply_dml_events_validation_query(
source_cursor, dest_cursor, table,
batch_start_timestamp, batch_end_timestamp, unmatched_pks
event_pk_batch = self.__get_event_pk_batch(
source_cursor, table, batch_start_timestamp, batch_end_timestamp
)
while event_pks := next(event_pk_batch, None):
self.__execute_apply_dml_events_validation_query(
source_cursor, dest_cursor, table, event_pks, unmatched_pks
)
except MySQLdb.OperationalError as e:
self.handle_operational_error(e, range_queue, batch_start_timestamp, batch_end_timestamp)
self.__handle_operational_error(e, range_queue, batch_start_timestamp, batch_end_timestamp)
source_conn.ping(True)
dest_conn.ping(True)
continue

def validate_unmatched_pks(self):
def __validate_unmatched_pks(self):
self.logger.info("Validating unmatched pks")
with self.db.cursor() as cursor:
cursor: Cursor
Expand Down Expand Up @@ -279,7 +286,7 @@ def validate_apply_dml_events(self, start_timestamp, end_timestamp):
if table_rows > 0:
range_queue = Queue()
batch_start_timestamp = start_timestamp
while batch_start_timestamp < end_timestamp:
while batch_start_timestamp <= end_timestamp:
batch_duration = \
(end_timestamp - start_timestamp) * self.apply_dml_events_batch_size // table_rows
batch_end_timestamp = min(batch_start_timestamp + batch_duration, end_timestamp)
Expand All @@ -290,15 +297,15 @@ def validate_apply_dml_events(self, start_timestamp, end_timestamp):
threads = []
for _ in range(self.thread_count):
threads.append(executor.submit(
self.validate_apply_dml_events_batch, table, range_queue, unmatched_pks))
self.__validate_apply_dml_events_batch, table, range_queue, unmatched_pks))
for thread in threads:
thread.result()

cursor.executemany(f'''
INSERT IGNORE INTO {config.SBOSC_DB}.unmatched_rows (source_pk, migration_id, unmatch_type)
VALUES (%s, {self.migration_id}, %s)
''', unmatched_pks)
self.validate_unmatched_pks()
self.__validate_unmatched_pks()
cursor.execute(
f"SELECT COUNT(1) FROM {config.SBOSC_DB}.unmatched_rows WHERE migration_id = {self.migration_id}")
unmatched_rows = cursor.fetchone()[0]
Expand All @@ -310,7 +317,7 @@ def validate_apply_dml_events(self, start_timestamp, end_timestamp):
def apply_dml_events_validation(self):
self.logger.info("Start apply DML events validation")

start_timestamp, end_timestamp = self.get_timestamp_range()
start_timestamp, end_timestamp = self.__get_timestamp_range()
if start_timestamp is None:
self.logger.warning("No events found. Skipping apply DML events validation")
return True
Expand All @@ -334,6 +341,10 @@ def full_dml_event_validation(self):
"""
:return: True if validation ran, False if validation skipped
"""
if self.full_dml_event_validation_interval == 0:
self.logger.info("Full DML event validation is disabled")
return False

self.logger.info("Start full DML event validation")

with self.db.cursor(role='reader') as cursor:
Expand All @@ -347,7 +358,9 @@ def full_dml_event_validation(self):
last_validation_time = cursor.fetchone()[0]
if datetime.now() - last_validation_time < timedelta(hours=self.full_dml_event_validation_interval):
self.logger.info(
"Last validation was done less than 1 hour ago. Skipping full DML event validation")
f"Last validation was done less than {self.full_dml_event_validation_interval} hour ago. "
f"Skipping full DML event validation"
)
return False

cursor.execute(f'''
Expand Down
30 changes: 15 additions & 15 deletions src/sbosc/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,33 +54,33 @@ def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk):
not_imported_pks = [row[0] for row in source_cursor.fetchall()]
return not_imported_pks

def get_not_inserted_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp):
def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks):
not_inserted_pks = []
event_pks = self._get_event_pks(source_cursor, 'insert', start_timestamp, end_timestamp)
if event_pks:
event_pks_str = ','.join([str(pk) for pk in event_pks])
source_cursor.execute(f'''
SELECT source.id FROM {self.source_db}.{self.source_table} AS source
LEFT JOIN {self.destination_db}.{self.destination_table} AS dest ON source.id = dest.id
WHERE source.id IN ({event_pks})
WHERE source.id IN ({event_pks_str})
AND dest.id IS NULL
''')
not_inserted_pks = [row[0] for row in source_cursor.fetchall()]
return not_inserted_pks

def get_not_updated_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp):
def get_not_updated_pks(self, source_cursor, dest_cursor, event_pks):
not_updated_pks = []
event_pks = self._get_event_pks(source_cursor, 'update', start_timestamp, end_timestamp)
if event_pks:
event_pks_str = ','.join([str(pk) for pk in event_pks])
source_cursor.execute(f'''
SELECT combined.id
FROM (
SELECT {self.source_columns}, 'source' AS table_type
FROM {self.source_db}.{self.source_table}
WHERE id IN ({event_pks})
WHERE id IN ({event_pks_str})
UNION ALL
SELECT {self.source_columns}, 'destination' AS table_type
FROM {self.destination_db}.{self.destination_table}
WHERE id IN ({event_pks})
WHERE id IN ({event_pks_str})
) AS combined
GROUP BY {self.source_columns}
HAVING COUNT(1) = 1 AND SUM(table_type = 'source') = 1
Expand Down Expand Up @@ -190,30 +190,30 @@ def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk):
dest_pks = [row[0] for row in dest_cursor.fetchall()]
return list(set(source_pks) - set(dest_pks))

def get_not_inserted_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp):
def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks):
not_inserted_pks = []
event_pks = self._get_event_pks(source_cursor, 'insert', start_timestamp, end_timestamp)
if event_pks:
source_cursor.execute(f"SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({event_pks})")
event_pks_str = ','.join([str(pk) for pk in event_pks])
source_cursor.execute(f"SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({event_pks_str})")
source_pks = [row[0] for row in source_cursor.fetchall()]
dest_cursor.execute(
f"SELECT id FROM {self.destination_db}.{self.destination_table} WHERE id IN ({event_pks})")
f"SELECT id FROM {self.destination_db}.{self.destination_table} WHERE id IN ({event_pks_str})")
dest_pks = [row[0] for row in dest_cursor.fetchall()]
not_inserted_pks = list(set(source_pks) - set(dest_pks))
return not_inserted_pks

def get_not_updated_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp):
def get_not_updated_pks(self, source_cursor, dest_cursor, event_pks):
not_updated_pks = []
event_pks = self._get_event_pks(source_cursor, 'update', start_timestamp, end_timestamp)
if event_pks:
event_pks_str = ','.join([str(pk) for pk in event_pks])
source_cursor.execute(f'''
SELECT {self.source_columns} FROM {self.source_db}.{self.source_table}
WHERE id IN ({event_pks})
WHERE id IN ({event_pks_str})
''')
source_df = pd.DataFrame(source_cursor.fetchall(), columns=[c[0] for c in source_cursor.description])
dest_cursor.execute(f'''
SELECT {self.source_columns} FROM {self.destination_db}.{self.destination_table}
WHERE id IN ({event_pks})
WHERE id IN ({event_pks_str})
''')
dest_df = pd.DataFrame(dest_cursor.fetchall(), columns=[c[0] for c in dest_cursor.description])

Expand Down
19 changes: 3 additions & 16 deletions src/sbosc/operations/operation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from abc import abstractmethod
from contextlib import contextmanager
from typing import Literal
from typing import List

from MySQLdb.cursors import Cursor

from config import config
from modules.db import Database
from modules.redis import RedisData

Expand Down Expand Up @@ -48,28 +47,16 @@ def get_not_imported_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start
"""
pass

def _get_event_pks(
self, cursor: Cursor, event_type: Literal['insert', 'update'], start_timestamp, end_timestamp):
table_names = {
'insert': f'inserted_pk_{self.migration_id}',
'update': f'updated_pk_{self.migration_id}'
}
cursor.execute(f'''
SELECT source_pk FROM {config.SBOSC_DB}.{table_names[event_type]}
WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp}
''')
return ','.join([str(row[0]) for row in cursor.fetchall()])

@abstractmethod
def get_not_inserted_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start_timestamp, end_timestamp):
def get_not_inserted_pks(self, source_cursor: Cursor, dest_cursor: Cursor, event_pks: List[int]):
"""
Returns a list of primary keys that have not been inserted into the destination table.
Used in APPLY_DML_EVENTS_VALIDATION stage to validate that all inserts have been applied.
"""
pass

@abstractmethod
def get_not_updated_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start_timestamp, end_timestamp):
def get_not_updated_pks(self, source_cursor: Cursor, dest_cursor: Cursor, event_pks: List[int]):
"""
Returns a list of primary keys that have not been updated in the destination table.
Used in APPLY_DML_EVENTS_VALIDATION stage to validate that all updates have been applied.
Expand Down
Loading

0 comments on commit 58c8fa4

Please sign in to comment.