Skip to content

Commit

Permalink
fix(trino): handle missing db in migration (apache#29997)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro authored Aug 22, 2024
1 parent 5906890 commit 17eecb1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
5 changes: 4 additions & 1 deletion superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,13 @@ def epoch_to_dttm(cls) -> str:
return "from_unixtime({col})"

@classmethod
def get_default_catalog(cls, database: "Database") -> str | None:
def get_default_catalog(cls, database: Database) -> str | None:
"""
Return the default catalog.
"""
if database.url_object.database is None:
return None

return database.url_object.database.split("/")[0]

@classmethod
Expand Down
22 changes: 13 additions & 9 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from __future__ import annotations

import copy
from collections import namedtuple
from datetime import datetime
Expand Down Expand Up @@ -719,7 +721,15 @@ def test_adjust_engine_params_catalog_only() -> None:
assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema"


def test_get_default_catalog() -> None:
@pytest.mark.parametrize(
"sqlalchemy_uri,result",
[
("trino://user:pass@localhost:8080/system", "system"),
("trino://user:pass@localhost:8080/system/default", "system"),
("trino://trino@localhost:8081", None),
],
)
def test_get_default_catalog(sqlalchemy_uri: str, result: str | None) -> None:
"""
Test the ``get_default_catalog`` method.
"""
Expand All @@ -728,15 +738,9 @@ def test_get_default_catalog() -> None:

database = Database(
database_name="my_db",
sqlalchemy_uri="trino://user:pass@localhost:8080/system",
)
assert TrinoEngineSpec.get_default_catalog(database) == "system"

database = Database(
database_name="my_db",
sqlalchemy_uri="trino://user:pass@localhost:8080/system/default",
sqlalchemy_uri=sqlalchemy_uri,
)
assert TrinoEngineSpec.get_default_catalog(database) == "system"
assert TrinoEngineSpec.get_default_catalog(database) == result


@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")
Expand Down

0 comments on commit 17eecb1

Please sign in to comment.