From e8f5d7680ff14342b2ed46cc0b8c3bd4463fa3c2 Mon Sep 17 00:00:00 2001 From: "Michael S. Molina" <70410625+michael-s-molina@users.noreply.github.com> Date: Fri, 16 Aug 2024 08:39:36 -0400 Subject: [PATCH] fix: upgrade_catalog_perms and downgrade_catalog_perms implementation (#29860) --- superset/migrations/shared/catalogs.py | 360 ++++++++++++++++++------- 1 file changed, 269 insertions(+), 91 deletions(-) diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py index b09c71739f8b7..b75214291b0d9 100644 --- a/superset/migrations/shared/catalogs.py +++ b/superset/migrations/shared/catalogs.py @@ -18,7 +18,8 @@ from __future__ import annotations import logging -from typing import Any, Type +from datetime import datetime +from typing import Any, Type, Union import sqlalchemy as sa from alembic import op @@ -35,8 +36,7 @@ ) from superset.models.core import Database -logger = logging.getLogger(__name__) - +logger = logging.getLogger("alembic") Base: Type[Any] = declarative_base() @@ -95,6 +95,16 @@ class Slice(Base): schema_perm = sa.Column(sa.String(1000)) +ModelType = Union[Type[Query], Type[SavedQuery], Type[TabState], Type[TableSchema]] + +MODELS: list[tuple[ModelType, str]] = [ + (Query, "database_id"), + (SavedQuery, "db_id"), + (TabState, "database_id"), + (TableSchema, "database_id"), +] + + def get_known_schemas(database_name: str, session: Session) -> list[str]: """ Read all known schemas from the existing schema permissions. @@ -112,6 +122,234 @@ def get_known_schemas(database_name: str, session: Session) -> list[str]: return sorted({name[0][1:-1].split("].[")[-1] for name in names}) +def get_batch_size(session: Session) -> int: + max_sqlite_in = 999 + return max_sqlite_in if session.bind.dialect.name == "sqlite" else 1_000_000 + + +def print_processed_batch( + start_time: datetime, + offset: int, + total_rows: int, + model: ModelType, + batch_size: int, +) -> None: + """ + Print the progress of batch processing. 
+ + This function logs the progress of processing a batch of rows from a model. + It calculates the elapsed time since the start of the batch processing and + logs the number of rows processed along with the percentage completion. + + Parameters: + start_time (datetime): The start time of the batch processing. + offset (int): The current offset in the batch processing. + total_rows (int): The total number of rows to process. + model (ModelType): The model being processed. + batch_size (int): The size of the batch being processed. + """ + elapsed_time = datetime.now() - start_time + elapsed_seconds = elapsed_time.total_seconds() + elapsed_formatted = f"{int(elapsed_seconds // 3600):02}:{int((elapsed_seconds % 3600) // 60):02}:{int(elapsed_seconds % 60):02}" + rows_processed = min(offset + batch_size, total_rows) + logger.info( + f"{elapsed_formatted} - {rows_processed:,} of {total_rows:,} {model.__tablename__} rows processed " + f"({(rows_processed / total_rows) * 100:.2f}%)" + ) + + +def update_catalog_column( + session: Session, database: Database, catalog: str, downgrade: bool = False +) -> None: + """ + Update the `catalog` column in the specified models to the given catalog. + + This function iterates over a list of models defined by MODELS and updates + the `catalog` column to the specified catalog or None depending on the downgrade + parameter. The update is performed in batches to optimize performance and reduce + memory usage. + + Parameters: + session (Session): The SQLAlchemy session to use for database operations. + database (Database): The database instance containing the models to update. + catalog (str): The new catalog value to set in the `catalog` column or + the default catalog if `downgrade` is True. + downgrade (bool): If True, the `catalog` column is set to None where the + catalog matches the specified catalog. 
+ """ + start_time = datetime.now() + + logger.info(f"Updating {database.database_name} models to catalog {catalog}") + + for model, column in MODELS: + # Get the total number of rows that match the condition + total_rows = ( + session.query(sa.func.count(model.id)) + .filter(getattr(model, column) == database.id) + .filter(model.catalog == catalog if downgrade else True) + .scalar() + ) + + logger.info( + f"Total rows to be processed for {model.__tablename__}: {total_rows:,}" + ) + + batch_size = get_batch_size(session) + limit_value = min(batch_size, total_rows) + + # Update in batches using row numbers + for i in range(0, total_rows, batch_size): + subquery = ( + session.query(model.id) + .filter(getattr(model, column) == database.id) + .filter(model.catalog == catalog if downgrade else True) + .order_by(model.id) + .offset(i) + .limit(limit_value) + .subquery() + ) + + # SQLite does not support multiple-table criteria within UPDATE + if session.bind.dialect.name == "sqlite": + ids_to_update = [row.id for row in session.query(subquery.c.id).all()] + if ids_to_update: + session.execute( + sa.update(model) + .where(model.id.in_(ids_to_update)) + .values(catalog=None if downgrade else catalog) + .execution_options(synchronize_session=False) + ) + else: + session.execute( + sa.update(model) + .where(model.id == subquery.c.id) + .values(catalog=None if downgrade else catalog) + .execution_options(synchronize_session=False) + ) + + print_processed_batch(start_time, i, total_rows, model, batch_size) + + +def update_schema_catalog_perms( + session: Session, + database: Database, + catalog_perm: str | None, + catalog: str, + downgrade: bool = False, +) -> None: + """ + Update schema and catalog permissions for tables and charts in a given database. + + This function updates the `catalog`, `catalog_perm`, and `schema_perm` fields for + tables and charts associated with the specified database. 
If `downgrade` is True, + the `catalog` and `catalog_perm` fields are set to None, otherwise they are set + to the provided `catalog` and `catalog_perm` values. + + Args: + session (Session): The SQLAlchemy session to use for database operations. + database (Database): The database object whose tables and charts will be updated. + catalog_perm (str): The new catalog permission to set. + catalog (str): The new catalog to set. + downgrade (bool, optional): If True, reset the `catalog` and `catalog_perm` fields to None. + Defaults to False. + """ + # Mapping of table id to schema permission + mapping = {} + + for table in ( + session.query(SqlaTable) + .filter_by(database_id=database.id) + .filter_by(catalog=catalog if downgrade else None) + ): + schema_perm = security_manager.get_schema_perm( + database.database_name, + None if downgrade else catalog, + table.schema, + ) + table.catalog = None if downgrade else catalog + table.catalog_perm = catalog_perm + table.schema_perm = schema_perm + mapping[table.id] = schema_perm + + # Select all slices of type table that belong to the database + for chart in ( + session.query(Slice) + .join(SqlaTable, Slice.datasource_id == SqlaTable.id) + .join(Database, SqlaTable.database_id == Database.id) + .filter(Database.id == database.id) + .filter(Slice.datasource_type == "table") + ): + # We only care about tables that exist in the mapping + if mapping.get(chart.datasource_id) is not None: + chart.catalog_perm = catalog_perm + chart.schema_perm = mapping[chart.datasource_id] + + +def delete_models_non_default_catalog( + session: Session, database: Database, catalog: str +) -> None: + """ + Delete models that are not in the default catalog. + + This function iterates over a list of models defined by MODELS and deletes + the rows where the `catalog` column does not match the specified catalog. + + Parameters: + session (Session): The SQLAlchemy session to use for database operations. 
+ database (Database): The database instance containing the models to delete. + catalog (str): The catalog to use to filter the models to delete. + """ + start_time = datetime.now() + + logger.info(f"Deleting models not in the default catalog: {catalog}") + + for model, column in MODELS: + # Get the total number of rows that match the condition + total_rows = ( + session.query(sa.func.count(model.id)) + .filter(getattr(model, column) == database.id) + .filter(model.catalog != catalog) + .scalar() + ) + + logger.info( + f"Total rows to be processed for {model.__tablename__}: {total_rows:,}" + ) + + batch_size = get_batch_size(session) + limit_value = min(batch_size, total_rows) + + # Delete in batches using row numbers + for i in range(0, total_rows, batch_size): + subquery = ( + session.query(model.id) + .filter(getattr(model, column) == database.id) + .filter(model.catalog != catalog) + .order_by(model.id) + .offset(i) + .limit(limit_value) + .subquery() + ) + + # SQLite does not support multiple-table criteria within DELETE + if session.bind.dialect.name == "sqlite": + ids_to_delete = [row.id for row in session.query(subquery.c.id).all()] + if ids_to_delete: + session.execute( + sa.delete(model) + .where(model.id.in_(ids_to_delete)) + .execution_options(synchronize_session=False) + ) + else: + session.execute( + sa.delete(model) + .where(model.id == subquery.c.id) + .execution_options(synchronize_session=False) + ) + + print_processed_batch(start_time, i, total_rows, model, batch_size) + + def upgrade_catalog_perms(engines: set[str] | None = None) -> None: """ Update models and permissions when catalogs are introduced in a DB engine spec. @@ -157,11 +395,13 @@ def upgrade_database_catalogs( """ Upgrade a given database to support the default catalog. 
""" - catalog_perm = security_manager.get_catalog_perm( + catalog_perm: str | None = security_manager.get_catalog_perm( database.database_name, default_catalog, ) - pvms: dict[str, tuple[str, ...]] = {catalog_perm: ("catalog_access",)} + pvms: dict[str, tuple[str, ...]] = ( + {catalog_perm: ("catalog_access",)} if catalog_perm else {} + ) # rename existing schema permissions to include the catalog, and also find any new # schemas @@ -170,39 +410,10 @@ def upgrade_database_catalogs( # update existing models that have a `catalog` column so it points to the default # catalog - models = [ - (Query, "database_id"), - (SavedQuery, "db_id"), - (TabState, "database_id"), - (TableSchema, "database_id"), - ] - for model, column in models: - for instance in session.query(model).filter( - getattr(model, column) == database.id - ): - instance.catalog = default_catalog + update_catalog_column(session, database, default_catalog, False) # update `schema_perm` and `catalog_perm` for tables and charts - for table in session.query(SqlaTable).filter_by( - database_id=database.id, - catalog=None, - ): - schema_perm = security_manager.get_schema_perm( - database.database_name, - default_catalog, - table.schema, - ) - - table.catalog = default_catalog - table.catalog_perm = catalog_perm - table.schema_perm = schema_perm - - for chart in session.query(Slice).filter_by( - datasource_id=table.id, - datasource_type="table", - ): - chart.catalog_perm = catalog_perm - chart.schema_perm = schema_perm + update_schema_catalog_perms(session, database, catalog_perm, default_catalog, False) # add any new catalogs discovered and their schemas new_catalog_pvms = add_non_default_catalogs(database, default_catalog, session) @@ -233,13 +444,15 @@ def add_non_default_catalogs( # edited. 
return {} - pvms = {} + pvms: dict[str, tuple[str]] = {} for catalog in catalogs: - perm = security_manager.get_catalog_perm(database.database_name, catalog) - pvms[perm] = ("catalog_access",) - - new_schema_pvms = create_schema_perms(database, catalog, session) - pvms.update(new_schema_pvms) + perm: str | None = security_manager.get_catalog_perm( + database.database_name, catalog + ) + if perm: + pvms[perm] = ("catalog_access",) + new_schema_pvms = create_schema_perms(database, catalog) + pvms.update(new_schema_pvms) return pvms @@ -266,12 +479,12 @@ def upgrade_schema_perms( perms = {} for schema in schemas: - current_perm = security_manager.get_schema_perm( + current_perm: str | None = security_manager.get_schema_perm( database.database_name, None, schema, ) - new_perm = security_manager.get_schema_perm( + new_perm: str | None = security_manager.get_schema_perm( database.database_name, default_catalog, schema, @@ -283,7 +496,7 @@ def upgrade_schema_perms( .one_or_none() ): existing_pvm.name = new_perm - else: + elif new_perm: # new schema discovered, need to create a new permission perms[new_perm] = ("schema_access",) @@ -293,7 +506,6 @@ def upgrade_schema_perms( def create_schema_perms( database: Database, catalog: str, - session: Session, ) -> dict[str, tuple[str]]: """ Create schema permissions for a given catalog. 
@@ -307,12 +519,14 @@ def create_schema_perms( return {} return { - security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ): ("schema_access",) + perm: ("schema_access",) for schema in schemas + if ( + perm := security_manager.get_schema_perm( + database.database_name, catalog, schema + ) + ) + is not None } @@ -374,49 +588,13 @@ def downgrade_database_catalogs( # permissions associated with other catalogs downgrade_schema_perms(database, default_catalog, session) - # update existing models - models = [ - (Query, "database_id"), - (SavedQuery, "db_id"), - (TabState, "database_id"), - (TableSchema, "database_id"), - ] - for model, column in models: - for instance in session.query(model).filter( - getattr(model, column) == database.id, - model.catalog == default_catalog, # type: ignore - ): - instance.catalog = None + update_catalog_column(session, database, default_catalog, True) - # update `schema_perm` for tables and charts - for table in session.query(SqlaTable).filter_by( - database_id=database.id, - catalog=default_catalog, - ): - schema_perm = security_manager.get_schema_perm( - database.database_name, - None, - table.schema, - ) - - table.catalog = None - table.catalog_perm = None - table.schema_perm = schema_perm - - for chart in session.query(Slice).filter_by( - datasource_id=table.id, - datasource_type="table", - ): - chart.catalog_perm = None - chart.schema_perm = schema_perm + # update `schema_perm` and `catalog_perm` for tables and charts + update_schema_catalog_perms(session, database, None, default_catalog, True) # delete models referencing non-default catalogs - for model, column in models: - for instance in session.query(model).filter( - getattr(model, column) == database.id, - model.catalog != default_catalog, # type: ignore - ): - session.delete(instance) + delete_models_non_default_catalog(session, database, default_catalog) # delete datasets and any associated permissions for table in session.query(SqlaTable).filter(