Skip to content

Commit

Permalink
feat: Adds helper functions for migrations (apache#31303)
Browse files Browse the repository at this point in the history
  • Loading branch information
luizotavio32 authored Dec 11, 2024
1 parent fd57fce commit 423a0fe
Show file tree
Hide file tree
Showing 13 changed files with 234 additions and 70 deletions.
22 changes: 0 additions & 22 deletions superset/migrations/shared/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
from dataclasses import dataclass

from alembic import op
from sqlalchemy.dialects.sqlite.base import SQLiteDialect # noqa: E402
from sqlalchemy.engine.reflection import Inspector

from superset.migrations.shared.utils import has_table
from superset.utils.core import generic_find_fk_constraint_name


Expand Down Expand Up @@ -73,23 +71,3 @@ def redefine(
ondelete=on_delete,
onupdate=on_update,
)


def drop_fks_for_table(table_name: str) -> None:
    """
    Remove every foreign key constraint defined on a table.

    Nothing is done when the active dialect is SQLite, or when the table is
    not present in the database.

    :param table_name: The table name to drop foreign key constraints for
    """
    bind = op.get_bind()
    insp = Inspector.from_engine(bind)

    # SQLite is skipped outright (sqlite doesn't like constraints); the
    # short-circuit keeps ``has_table`` from being queried in that case.
    if isinstance(bind.dialect, SQLiteDialect) or not has_table(table_name):
        return

    for foreign_key in insp.get_foreign_keys(table_name):
        op.drop_constraint(foreign_key["name"], table_name, type_="foreignkey")
225 changes: 213 additions & 12 deletions superset/migrations/shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,25 @@
from typing import Any, Callable, Optional, Union
from uuid import uuid4

import sqlalchemy as sa
from alembic import op
from sqlalchemy import inspect
from sqlalchemy import Column, inspect
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.dialects.sqlite.base import SQLiteDialect # noqa: E402
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import Query, Session
from sqlalchemy.sql.schema import SchemaItem

from superset.utils import json

logger = logging.getLogger(__name__)
GREEN = "\033[32m"
RESET = "\033[0m"
YELLOW = "\033[33m"
RED = "\033[31m"
LRED = "\033[91m"

logger = logging.getLogger("alembic")

DEFAULT_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1000))

Expand Down Expand Up @@ -185,15 +193,208 @@ def has_table(table_name: str) -> bool:
return table_exists


def add_column_if_not_exists(table_name: str, column: sa.Column) -> None:
def drop_fks_for_table(table_name: str) -> None:
    """
    Drop all foreign key constraints for a table if it exists and the database
    is not sqlite.

    Each dropped constraint is reported on the module logger.

    :param table_name: The table name to drop foreign key constraints for
    """
    connection = op.get_bind()

    if isinstance(connection.dialect, SQLiteDialect):
        return  # sqlite doesn't like constraints

    if not has_table(table_name):
        return

    # ``Inspector.from_engine`` is deprecated since SQLAlchemy 1.4;
    # ``sqlalchemy.inspect`` is the supported way to obtain an Inspector.
    # It is created only here so the early-return paths above do no
    # reflection work.
    inspector = inspect(connection)
    for fk in inspector.get_foreign_keys(table_name):
        logger.info(
            f"Dropping foreign key {GREEN}{fk['name']}{RESET} "
            f"from table {GREEN}{table_name}{RESET}..."
        )
        op.drop_constraint(fk["name"], table_name, type_="foreignkey")


def create_table(table_name: str, *columns: SchemaItem) -> None:
    """
    Create a database table unless one with the same name already exists.

    When the table is already present an informational message is logged and
    nothing else happens. Otherwise the table is created through Alembic's
    ``op.create_table`` and the start and end of the creation are logged.

    :param table_name: The name of the table to be created.
    :param columns: The schema items describing the table, passed through
        unchanged to alembic's method create_table()
    """
    if not has_table(table_name=table_name):
        logger.info(f"Creating table {GREEN}{table_name}{RESET}...")
        op.create_table(table_name, *columns)
        logger.info(f"Table {GREEN}{table_name}{RESET} created.")
    else:
        logger.info(f"Table {LRED}{table_name}{RESET} already exists. Skipping...")


def drop_table(table_name: str) -> None:
"""
Adds a column to a table if it does not already exist.
Drops a database table with the specified name.
:param table_name: Name of the table.
:param column: SQLAlchemy Column object.
This function checks if a table with the given name exists in the database.
If the table does not exist, it logs an informational message and skips the dropping process.
If the table exists, it first attempts to drop all foreign key constraints associated with the table
(handled by `drop_fks_for_table`) and then proceeds to drop the table.
:param table_name: The name of the table to be dropped.
"""
if not table_has_column(table_name, column.name):
print(f"Adding column '{column.name}' to table '{table_name}'.\n")
op.add_column(table_name, column)
else:
print(f"Column '{column.name}' already exists in table '{table_name}'.\n")

if not has_table(table_name=table_name):
logger.info(f"Table {GREEN}{table_name}{RESET} doesn't exist. Skipping...")
return

logger.info(f"Dropping table {GREEN}{table_name}{RESET}...")
drop_fks_for_table(table_name)
op.drop_table(table_name=table_name)
logger.info(f"Table {GREEN}{table_name}{RESET} dropped.")


def batch_operation(
    callable: Callable[[int, int], None], count: int, batch_size: int
) -> None:
    """
    Execute an operation in consecutive batches, logging progress as it goes.

    The range ``[0, count)`` is split into windows of at most ``batch_size``
    items and ``callable(start, end)`` is invoked once per window, where
    ``end`` is exclusive and never exceeds ``count``.

    If count is 0 or lower, an informational message is logged and no batch
    is processed.

    :param callable: A callable function that takes two integer arguments:
        the start index and the end index of the current batch.
    :param count: The total number of items to process.
    :param batch_size: The number of items to process in each batch.
    :raises ValueError: If ``batch_size`` is not a positive integer while
        there are records to process.
    """
    # Callables such as ``functools.partial`` objects have no ``__name__``;
    # fall back to their repr so logging can never fail the migration.
    name = getattr(callable, "__name__", repr(callable))

    if count <= 0:
        logger.info(
            f"No records to process in batch {LRED}(count <= 0){RESET} "
            f"for callable {LRED}{name}{RESET}. Skipping..."
        )
        return

    # A zero step makes ``range`` raise and a negative step yields an empty
    # range, which would report success without processing anything.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    for offset in range(0, count, batch_size):
        percentage = (offset / count) * 100  # count > 0 is guaranteed here
        logger.info(f"Progress: {offset:,}/{count:,} ({percentage:.2f}%)")
        callable(offset, min(offset + batch_size, count))

    logger.info(f"Progress: {count:,}/{count:,} (100%)")
    logger.info(
        f"End: {GREEN}{name}{RESET} batch operation "
        f"{GREEN}successfully{RESET} executed."
    )


def add_columns(table_name: str, *columns: Column) -> None:
    """
    Add columns to an existing table, ignoring those already present.

    Every requested column is first checked against the table; columns that
    already exist are reported with an informational log message and left
    alone. The remaining columns are then added inside a single Alembic
    ``batch_alter_table`` context.

    :param table_name: The name of the table to which the columns will be added.
    :param columns: SQLAlchemy Column objects that define the name, type, and
        other attributes of the columns to be added.
    """
    pending = []
    for column in columns:
        if not table_has_column(table_name=table_name, column_name=column.name):
            pending.append(column)
            continue
        logger.info(
            f"Column {LRED}{column.name}{RESET} already present on table "
            f"{LRED}{table_name}{RESET}. Skipping..."
        )

    with op.batch_alter_table(table_name) as batch_op:
        for column in pending:
            logger.info(
                f"Adding column {GREEN}{column.name}{RESET} to table "
                f"{GREEN}{table_name}{RESET}..."
            )
            batch_op.add_column(column)


def drop_columns(table_name: str, *columns: str) -> None:
    """
    Drop columns from an existing table, ignoring those that are absent.

    Every requested column name is first checked against the table; names
    that are not present are reported with an informational log message and
    skipped. The remaining columns are then dropped inside a single Alembic
    ``batch_alter_table`` context.

    :param table_name: The name of the table from which the columns will be removed.
    :param columns: The names of the columns to be dropped.
    """
    present = []
    for column_name in columns:
        if table_has_column(table_name=table_name, column_name=column_name):
            present.append(column_name)
            continue
        logger.info(
            f"Column {LRED}{column_name}{RESET} is not present on table "
            f"{LRED}{table_name}{RESET}. Skipping..."
        )

    with op.batch_alter_table(table_name) as batch_op:
        for column_name in present:
            logger.info(
                f"Dropping column {GREEN}{column_name}{RESET} from table "
                f"{GREEN}{table_name}{RESET}..."
            )
            batch_op.drop_column(column_name)


def create_index(table_name: str, index_name: str, *columns: str) -> None:
    """
    Create a named index on a table unless that index already exists.

    When the table already has an index with the given name, an informational
    message is logged and nothing is created. Otherwise the index is created
    through Alembic's ``op.create_index`` on the given columns.

    :param table_name: The name of the table on which the index will be created.
    :param index_name: The name of the index to be created.
    :param columns: The names of the columns the index will cover.
    """
    already_indexed = table_has_index(table=table_name, index=index_name)

    if not already_indexed:
        logger.info(
            f"Creating index {GREEN}{index_name}{RESET} on table "
            f"{GREEN}{table_name}{RESET}"
        )
        op.create_index(
            table_name=table_name, index_name=index_name, columns=columns
        )
    else:
        logger.info(
            f"Table {LRED}{table_name}{RESET} already has index "
            f"{LRED}{index_name}{RESET}. Skipping..."
        )


def drop_index(table_name: str, index_name: str) -> None:
    """
    Drop a named index from a table if that index exists.

    When the table has no index with the given name, an informational message
    is logged and nothing is dropped. Otherwise the index is removed through
    Alembic's ``op.drop_index``.

    :param table_name: The name of the table from which the index will be dropped.
    :param index_name: The name of the index to be dropped.
    """
    index_present = table_has_index(table=table_name, index=index_name)

    if index_present:
        logger.info(
            f"Dropping index {GREEN}{index_name}{RESET} from table "
            f"{GREEN}{table_name}{RESET}..."
        )
        op.drop_index(table_name=table_name, index_name=index_name)
    else:
        logger.info(
            f"Table {LRED}{table_name}{RESET} doesn't have index "
            f"{LRED}{index_name}{RESET}. Skipping..."
        )
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from alembic import op # noqa: E402
from sqlalchemy_utils import EncryptedType # noqa: E402

from superset.migrations.shared.constraints import drop_fks_for_table # noqa: E402
from superset.migrations.shared.utils import drop_fks_for_table # noqa: E402


def upgrade():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,20 @@
"""

import sqlalchemy as sa
from alembic import op

from superset.migrations.shared.utils import add_column_if_not_exists
from superset.migrations.shared.utils import add_columns, drop_columns

# revision identifiers, used by Alembic.
revision = "c22cb5c2e546"
down_revision = "678eefb4ab44"


def upgrade():
add_column_if_not_exists(
add_columns(
"user_attribute",
sa.Column("avatar_url", sa.String(length=100), nullable=True),
)


def downgrade():
op.drop_column("user_attribute", "avatar_url")
drop_columns("user_attribute", "avatar_url")
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
"""

import sqlalchemy as sa
from alembic import op

from superset.migrations.shared.utils import add_column_if_not_exists
from superset.migrations.shared.utils import add_columns, drop_columns

# revision identifiers, used by Alembic.
revision = "5f57af97bc3f"
Expand All @@ -36,12 +35,9 @@

def upgrade():
for table in tables:
add_column_if_not_exists(
table,
sa.Column("catalog", sa.String(length=256), nullable=True),
)
add_columns(table, sa.Column("catalog", sa.String(length=256), nullable=True))


def downgrade():
for table in reversed(tables):
op.drop_column(table, "catalog")
drop_columns(table, "catalog")
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,29 @@
"""

import sqlalchemy as sa
from alembic import op

from superset.migrations.shared.catalogs import (
downgrade_catalog_perms,
upgrade_catalog_perms,
)
from superset.migrations.shared.utils import add_column_if_not_exists
from superset.migrations.shared.utils import add_columns, drop_columns

# revision identifiers, used by Alembic.
revision = "58d051681a3b"
down_revision = "4a33124c18ad"


def upgrade():
add_column_if_not_exists(
"tables",
sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
add_columns(
"tables", sa.Column("catalog_perm", sa.String(length=1000), nullable=True)
)
add_column_if_not_exists(
"slices",
sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
add_columns(
"slices", sa.Column("catalog_perm", sa.String(length=1000), nullable=True)
)
upgrade_catalog_perms(engines={"postgresql"})


def downgrade():
downgrade_catalog_perms(engines={"postgresql"})
op.drop_column("slices", "catalog_perm")
op.drop_column("tables", "catalog_perm")
drop_columns("slices", "catalog_perm")
drop_columns("tables", "catalog_perm")
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
import sqlalchemy as sa
from alembic import op

from superset.migrations.shared.constraints import drop_fks_for_table
from superset.migrations.shared.utils import has_table
from superset.migrations.shared.utils import drop_fks_for_table, has_table

# revision identifiers, used by Alembic.
revision = "02f4f7811799"
Expand Down
Loading

0 comments on commit 423a0fe

Please sign in to comment.