diff --git a/superset/config.py b/superset/config.py index 0234c0deb230b..f0ba25a3eb088 100644 --- a/superset/config.py +++ b/superset/config.py @@ -68,6 +68,7 @@ if TYPE_CHECKING: from flask_appbuilder.security.sqla import models + from sqlglot import Dialect, Dialects from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database @@ -249,6 +250,10 @@ def _try_json_readsha(filepath: str, length: int) -> str | None: SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER = ( # pylint: disable=invalid-name SQLAlchemyUtilsAdapter ) + +# Extends the default SQLGlot dialects with additional dialects +SQLGLOT_DIALECTS_EXTENSIONS: map[str, Dialects | type[Dialect]] = {} + # The limit of queries fetched for query search QUERY_SEARCH_LIMIT = 1000 diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 10686b872ff41..1720c87af48e2 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -54,6 +54,7 @@ talisman, ) from superset.security import SupersetSecurityManager +from superset.sql.parse import SQLGLOT_DIALECTS from superset.superset_typing import FlaskResponse from superset.tags.core import register_sqla_event_listeners from superset.utils.core import is_test, pessimistic_connection_handling @@ -484,6 +485,7 @@ def init_app(self) -> None: self.configure_middlewares() self.configure_cache() self.set_db_default_isolation() + self.configure_sqlglot_dialects() with self.superset_app.app_context(): self.init_app_in_ctx() @@ -544,6 +546,9 @@ def configure_cache(self) -> None: def configure_feature_flags(self) -> None: feature_flag_manager.init_app(self.superset_app) + def configure_sqlglot_dialects(self) -> None: + SQLGLOT_DIALECTS.update(self.config["SQLGLOT_DIALECTS_EXTENSIONS"]) + @transaction() def configure_fab(self) -> None: if self.config["SILENCE_FAB"]: diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 91b68126356a2..1581b0c6e79e6 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -28,7 +28,6 @@ from flask_babel import gettext as __ from jinja2 import nodes from sqlalchemy import and_ -from sqlglot.dialects.dialect import Dialects from sqlparse import keywords from sqlparse.lexer import Lexer from sqlparse.sql import ( @@ -61,7 +60,12 @@ SupersetParseError, SupersetSecurityException, ) -from superset.sql.parse import extract_tables_from_statement, SQLScript, Table +from superset.sql.parse import ( + extract_tables_from_statement, + SQLGLOT_DIALECTS, + SQLScript, + Table, +) from superset.utils.backports import StrEnum try: @@ -88,61 +92,6 @@ lex.set_SQL_REGEX(sqlparser_sql_regex) -# mapping between DB engine specs and sqlglot dialects -SQLGLOT_DIALECTS = { - "ascend": Dialects.HIVE, - "awsathena": Dialects.PRESTO, - "bigquery": Dialects.BIGQUERY, - "clickhouse": Dialects.CLICKHOUSE, - "clickhousedb": Dialects.CLICKHOUSE, - "cockroachdb": Dialects.POSTGRES, - "couchbase": Dialects.MYSQL, - # "crate": ??? - # "databend": ??? - "databricks": Dialects.DATABRICKS, - # "db2": ??? - # "dremio": ??? - "drill": Dialects.DRILL, - # "druid": ??? - "duckdb": Dialects.DUCKDB, - # "dynamodb": ??? - # "elasticsearch": ??? - # "exa": ??? - # "firebird": ??? - # "firebolt": ??? - "gsheets": Dialects.SQLITE, - "hana": Dialects.POSTGRES, - "hive": Dialects.HIVE, - # "ibmi": ??? - # "impala": ??? - # "kustokql": ??? - # "kylin": ??? - "mssql": Dialects.TSQL, - "mysql": Dialects.MYSQL, - "netezza": Dialects.POSTGRES, - # "ocient": ??? - # "odelasticsearch": ??? - "oracle": Dialects.ORACLE, - # "pinot": ??? - "postgresql": Dialects.POSTGRES, - "presto": Dialects.PRESTO, - "pydoris": Dialects.DORIS, - "redshift": Dialects.REDSHIFT, - # "risingwave": ??? - # "rockset": ??? - "shillelagh": Dialects.SQLITE, - "snowflake": Dialects.SNOWFLAKE, - # "solr": ??? - "spark": Dialects.SPARK, - "sqlite": Dialects.SQLITE, - "starrocks": Dialects.STARROCKS, - "superset": Dialects.SQLITE, - "teradatasql": Dialects.TERADATA, - "trino": Dialects.TRINO, - "vertica": Dialects.POSTGRES, -} - - class CtasMethod(StrEnum): TABLE = "TABLE" VIEW = "VIEW" diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 6c1e5791277b3..ae5ebf89a8b96 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -18,6 +18,7 @@ import pytest +from sqlglot import Dialects from superset.exceptions import SupersetParseError from superset.sql.parse import ( @@ -932,3 +933,15 @@ def test_get_settings() -> None: SELECT * FROM some_table; """ assert SQLScript(sql, "postgresql").get_settings() == {"search_path": "bar"} + + +@pytest.mark.parametrize( + "app", + [{"SQLGLOT_DIALECTS_EXTENSIONS": {"custom": Dialects.MYSQL}}], + indirect=True, +) +def test_custom_dialect(app: None) -> None: + """ + Test that custom dialects are loaded correctly. + """ + assert SQLGLOT_DIALECTS.get("custom") == Dialects.MYSQL