Skip to content

Commit

Permalink
fix: make catalog migration lenient (apache#29549)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Jul 11, 2024
1 parent 33b934c commit d535f3f
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 38 deletions.
117 changes: 79 additions & 38 deletions superset/migrations/shared/catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging

import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

from superset import db, security_manager
from superset.daos.database import DatabaseDAO
Expand Down Expand Up @@ -86,6 +87,24 @@ class Slice(Base):
schema_perm = sa.Column(sa.String(1000))


def get_schemas(database_name: str) -> list[str]:
    """
    Read all known schemas from the schema permissions.

    Used as a fallback when the analytical database cannot be reached: the
    names of the existing ``schema_access`` permissions already encode every
    schema that was synced, e.g. ``[PostgreSQL].[postgres].[public]``.
    """
    # Bind the LIKE prefix as a parameter instead of interpolating the database
    # name into the SQL, so names containing quotes can't break (or inject
    # into) the query.
    query = sa.text(
        """
        SELECT
            avm.name
        FROM ab_view_menu avm
        JOIN ab_permission_view apv ON avm.id = apv.view_menu_id
        JOIN ab_permission ap ON apv.permission_id = ap.id
        WHERE
            avm.name LIKE :prefix AND
            ap.name = 'schema_access'
        """
    ).bindparams(prefix=f"[{database_name}]%")

    # [PostgreSQL].[postgres].[public] => public
    return sorted({row[0].split(".")[-1][1:-1] for row in op.execute(query)})


def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Update models when catalogs are introduced in a DB engine spec.
Expand Down Expand Up @@ -116,25 +135,7 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
)
add_pvms(session, {perm: ("catalog_access",)})

# update schema_perms
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
for schema in database.get_all_schema_names(
catalog=catalog,
cache=False,
ssh_tunnel=ssh_tunnel,
):
perm = security_manager.get_schema_perm(
database.database_name,
None,
schema,
)
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
if existing_pvm:
existing_pvm.name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
upgrade_schema_perms(database, catalog, session)

# update existing models
models = [
Expand Down Expand Up @@ -166,6 +167,35 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
session.commit()


def upgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename existing schema permissions to include the catalog.

    For every schema in the database's default catalog, the permission
    ``[database].[schema]`` is renamed to ``[database].[catalog].[schema]``.

    :param database: the database whose schema permissions are being renamed
    :param catalog: the default catalog to inject into the permission names
    :param session: the migration's DB session
    """
    ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schemas = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # The analytical database may be offline during the migration; fall
        # back to the schemas recorded in existing permissions so the
        # migration can proceed, but log the failure so it isn't fully silent.
        logging.getLogger(__name__).warning(
            "Unable to list schemas of database %s; falling back to existing "
            "schema permissions",
            database.database_name,
        )
        schemas = get_schemas(database.database_name)

    for schema in schemas:
        # look up the pre-catalog permission name ([db].[schema])...
        perm = security_manager.get_schema_perm(
            database.database_name,
            None,
            schema,
        )
        existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
        if existing_pvm:
            # ...and rename it in place to include the catalog
            existing_pvm.name = security_manager.get_schema_perm(
                database.database_name,
                catalog,
                schema,
            )


def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Reverse the process of `upgrade_catalog_perms`.
Expand All @@ -183,25 +213,7 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
if catalog is None:
continue

# update schema_perms
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
for schema in database.get_all_schema_names(
catalog=catalog,
cache=False,
ssh_tunnel=ssh_tunnel,
):
perm = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
if existing_pvm:
existing_pvm.name = security_manager.get_schema_perm(
database.database_name,
None,
schema,
)
downgrade_schema_perms(database, catalog, session)

# update existing models
models = [
Expand Down Expand Up @@ -231,3 +243,32 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
chart.schema_perm = schema_perm

session.commit()


def downgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename existing schema permissions to omit the catalog.

    Reverses `upgrade_schema_perms`: for every schema in the database's
    default catalog, the permission ``[database].[catalog].[schema]`` is
    renamed back to ``[database].[schema]``.

    :param database: the database whose schema permissions are being renamed
    :param catalog: the default catalog to strip from the permission names
    :param session: the migration's DB session
    """
    ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schemas = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # The analytical database may be offline during the migration; fall
        # back to the schemas recorded in existing permissions so the
        # migration can proceed, but log the failure so it isn't fully silent.
        logging.getLogger(__name__).warning(
            "Unable to list schemas of database %s; falling back to existing "
            "schema permissions",
            database.database_name,
        )
        schemas = get_schemas(database.database_name)

    for schema in schemas:
        # look up the catalog-qualified permission name ([db].[catalog].[schema])...
        perm = security_manager.get_schema_perm(
            database.database_name,
            catalog,
            schema,
        )
        existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
        if existing_pvm:
            # ...and rename it in place to drop the catalog again
            existing_pvm.name = security_manager.get_schema_perm(
                database.database_name,
                None,
                schema,
            )
125 changes: 125 additions & 0 deletions tests/unit_tests/migrations/shared/catalogs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,128 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
("[my_db].[public]",),
("[my_db].[db]",),
]


def test_upgrade_catalog_perms_graceful(
    mocker: MockerFixture,
    session: Session,
) -> None:
    """
    Test the `upgrade_catalog_perms` function when it fails to connect to the DB.

    During the migration we try to connect to the analytical database to get the list
    of schemas. This should fail gracefully and not raise an exception, since the
    database could be offline, and the permissions can be generated later, when the
    admin enables catalog browsing on the database (permissions are always synced on
    a DB update, see `UpdateDatabaseCommand`).
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database
    from superset.models.slice import Slice
    from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState

    engine = session.get_bind()
    Database.metadata.create_all(engine)

    db = mocker.patch("superset.migrations.shared.catalogs.db")
    db.Session.return_value = session

    # simulate the analytical database being offline
    mocker.patch.object(
        Database,
        "get_all_schema_names",
        side_effect=Exception("Failed to connect to the database"),
    )
    # `get_schemas` falls back to reading permissions via `op.execute`, so route
    # it to the test session (a single patch; the session can run raw SQL)
    mocker.patch("superset.migrations.shared.catalogs.op", session)

    database = Database(
        database_name="my_db",
        sqlalchemy_uri="postgresql://localhost/db",
    )
    dataset = SqlaTable(
        table_name="my_table",
        database=database,
        catalog=None,
        schema="public",
        schema_perm="[my_db].[public]",
    )
    session.add(dataset)
    session.commit()

    chart = Slice(
        slice_name="my_chart",
        datasource_type="table",
        datasource_id=dataset.id,
    )
    query = Query(
        client_id="foo",
        database=database,
        catalog=None,
        schema="public",
    )
    saved_query = SavedQuery(
        database=database,
        sql="SELECT * FROM public.t",
        catalog=None,
        schema="public",
    )
    tab_state = TabState(
        database=database,
        catalog=None,
        schema="public",
    )
    table_schema = TableSchema(
        database=database,
        catalog=None,
        schema="public",
    )
    session.add_all([chart, query, saved_query, tab_state, table_schema])
    session.commit()

    # before migration
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
    ]

    upgrade_catalog_perms()

    # after migration
    assert dataset.catalog == "db"
    assert query.catalog == "db"
    assert saved_query.catalog == "db"
    assert tab_state.catalog == "db"
    assert table_schema.catalog == "db"
    assert dataset.schema_perm == "[my_db].[db].[public]"
    assert chart.schema_perm == "[my_db].[db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[db].[public]",),
        ("[my_db].[db]",),
    ]

    downgrade_catalog_perms()

    # revert
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
        ("[my_db].[db]",),
    ]

0 comments on commit d535f3f

Please sign in to comment.