From 8d73ab9955ce913db5b02836f12de8bd75d162ec Mon Sep 17 00:00:00 2001 From: Rob Moore Date: Mon, 20 Nov 2023 17:59:10 +0000 Subject: [PATCH] feat(sqllab): TRINO_EXPAND_ROWS: expand columns from ROWs (#25809) --- .../databases/DatabaseModal/ExtraOptions.tsx | 18 ++- .../databases/DatabaseModal/index.test.tsx | 17 ++- .../databases/DatabaseModal/index.tsx | 12 ++ .../src/features/databases/types.ts | 3 + superset/db_engine_specs/base.py | 19 ++- superset/db_engine_specs/druid.py | 11 -- superset/db_engine_specs/hive.py | 8 +- superset/db_engine_specs/presto.py | 7 +- superset/db_engine_specs/trino.py | 62 +++++++++ superset/models/core.py | 10 +- superset/superset_typing.py | 2 + .../unit_tests/db_engine_specs/test_trino.py | 122 ++++++++++++++++++ 12 files changed, 268 insertions(+), 23 deletions(-) diff --git a/superset-frontend/src/features/databases/DatabaseModal/ExtraOptions.tsx b/superset-frontend/src/features/databases/DatabaseModal/ExtraOptions.tsx index 55c3875f98a62..45706da5868da 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/ExtraOptions.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/ExtraOptions.tsx @@ -202,7 +202,7 @@ const ExtraOptions = ({ /> - +
+ +
+ + +
+
diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx index bcd9fbe694706..ba443e0099457 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx @@ -674,7 +674,7 @@ describe('DatabaseModal', () => { const exposeInSQLLabCheckbox = screen.getByRole('checkbox', { name: /expose database in sql lab/i, }); - // This is both the checkbox and it's respective SVG + // This is both the checkbox and its respective SVG // const exposeInSQLLabCheckboxSVG = checkboxOffSVGs[0].parentElement; const exposeInSQLLabText = screen.getByText( /expose database in sql lab/i, @@ -721,6 +721,13 @@ describe('DatabaseModal', () => { /Disable SQL Lab data preview queries/i, ); + const enableRowExpansionCheckbox = screen.getByRole('checkbox', { + name: /enable row expansion in schemas/i, + }); + const enableRowExpansionText = screen.getByText( + /enable row expansion in schemas/i, + ); + // ---------- Assertions ---------- const visibleComponents = [ closeButton, @@ -737,6 +744,7 @@ describe('DatabaseModal', () => { checkboxOffSVGs[2], checkboxOffSVGs[3], checkboxOffSVGs[4], + checkboxOffSVGs[5], tooltipIcons[0], tooltipIcons[1], tooltipIcons[2], @@ -744,6 +752,7 @@ describe('DatabaseModal', () => { tooltipIcons[4], tooltipIcons[5], tooltipIcons[6], + tooltipIcons[7], exposeInSQLLabText, allowCTASText, allowCVASText, @@ -754,6 +763,7 @@ describe('DatabaseModal', () => { enableQueryCostEstimationText, allowDbExplorationText, disableSQLLabDataPreviewQueriesText, + enableRowExpansionText, ]; // These components exist in the DOM but are not visible const invisibleComponents = [ @@ -764,6 +774,7 @@ describe('DatabaseModal', () => { enableQueryCostEstimationCheckbox, allowDbExplorationCheckbox, disableSQLLabDataPreviewQueriesCheckbox, + enableRowExpansionCheckbox, ]; visibleComponents.forEach(component => { expect(component).toBeVisible(); @@ -771,8 +782,8 @@ describe('DatabaseModal', () => { invisibleComponents.forEach(component => { expect(component).not.toBeVisible(); }); - expect(checkboxOffSVGs).toHaveLength(5); - expect(tooltipIcons).toHaveLength(7); + expect(checkboxOffSVGs).toHaveLength(6); + expect(tooltipIcons).toHaveLength(8); }); test('renders the "Advanced" - PERFORMANCE tab correctly', async () => { diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.tsx index 0c1ac56369692..18c93f2bf462f 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx @@ -307,6 +307,18 @@ export function dbReducer( }), }; } + if (action.payload.name === 'expand_rows') { + return { + ...trimmedState, + extra: JSON.stringify({ + ...extraJson, + schema_options: { + ...extraJson?.schema_options, + [action.payload.name]: !!action.payload.value, + }, + }), + }; + } return { ...trimmedState, extra: JSON.stringify({ diff --git a/superset-frontend/src/features/databases/types.ts b/superset-frontend/src/features/databases/types.ts index e138a9143669e..1d616fa13c053 100644 --- a/superset-frontend/src/features/databases/types.ts +++ b/superset-frontend/src/features/databases/types.ts @@ -226,5 +226,8 @@ export interface ExtraJson { table_cache_timeout?: number; // in Performance }; // No field, holds schema and table timeout schemas_allowed_for_file_upload?: string[]; // in Security + schema_options?: { + expand_rows?: boolean; + }; version?: string; } diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 6bce03d931710..9894232ab1bdb 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -51,7 +51,7 @@ from sqlalchemy.engine.url import URL from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import Session -from sqlalchemy.sql import quoted_name, text +from sqlalchemy.sql import literal_column, quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause from sqlalchemy.types import TypeEngine from sqlparse.tokens import CTE @@ -1322,8 +1322,12 @@ def get_table_comment( return comment @classmethod - def get_columns( - cls, inspector: Inspector, table_name: str, schema: str | None + def get_columns( # pylint: disable=unused-argument + cls, + inspector: Inspector, + table_name: str, + schema: str | None, + options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ Get all columns from a given schema and table @@ -1331,6 +1335,8 @@ def get_columns( :param inspector: SqlAlchemy Inspector instance :param table_name: Table name :param schema: Schema name. If omitted, uses default schema for database + :param options: Extra options to customise the display of columns in + some databases :return: All columns in table """ return convert_inspector_columns( @@ -1382,7 +1388,12 @@ def where_latest_partition( # pylint: disable=too-many-arguments,unused-argumen @classmethod def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: - return [column(c["column_name"]) for c in cols] + return [ + literal_column(query_as) + if (query_as := c.get("query_as")) + else column(c["column_name"]) + for c in cols + ] @classmethod def select_star( # pylint: disable=too-many-arguments,too-many-locals diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 9bba3a727438b..7cd85ec924cf9 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -23,14 +23,12 @@ from typing import Any, TYPE_CHECKING from sqlalchemy import types -from sqlalchemy.engine.reflection import Inspector from superset import is_feature_enabled from superset.constants import TimeGrain from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError from superset.exceptions import SupersetException -from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils if TYPE_CHECKING: @@ -130,15 +128,6 @@ def epoch_ms_to_dttm(cls) -> str: """ return "MILLIS_TO_TIMESTAMP({col})" - @classmethod - def get_columns( - cls, inspector: Inspector, table_name: str, schema: str | None - ) -> list[ResultSetColumnType]: - """ - Update the Druid type map. - """ - return super().get_columns(inspector, table_name, schema) - @classmethod def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 4a881e15b276b..bd303f928d625 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -410,9 +410,13 @@ def handle_cursor( # pylint: disable=too-many-locals @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: str | None + cls, + inspector: Inspector, + table_name: str, + schema: str | None, + options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: - return BaseEngineSpec.get_columns(inspector, table_name, schema) + return BaseEngineSpec.get_columns(inspector, table_name, schema, options) @classmethod def where_latest_partition( # pylint: disable=too-many-arguments diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 8afa82d9b55d9..27e86a7980875 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -981,7 +981,11 @@ def _show_columns( @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: str | None + cls, + inspector: Inspector, + table_name: str, + schema: str | None, + options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ Get columns from a Presto data source. This includes handling row and @@ -989,6 +993,7 @@ def get_columns( :param inspector: object that performs database schema inspection :param table_name: table name :param schema: schema name + :param options: Extra configuration options, not used by this backend :return: a list of results that contain column info (i.e. column name and data type) """ diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 125a96ab82301..d1c8e20bea9ee 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -24,8 +24,10 @@ import simplejson as json from flask import current_app +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.orm import Session +from trino.sqlalchemy import datatype from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT from superset.databases.utils import make_url_safe @@ -33,6 +35,7 @@ from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError from superset.db_engine_specs.presto import PrestoBaseEngineSpec from superset.models.sql_lab import Query +from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils if TYPE_CHECKING: @@ -331,3 +334,62 @@ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: return { requests_exceptions.ConnectionError: SupersetDBAPIConnectionError, } + + @classmethod + def _expand_columns(cls, col: ResultSetColumnType) -> list[ResultSetColumnType]: + """ + Expand the given column out to one or more columns by analysing their types, + descending into ROWS and expanding out their inner fields recursively. + + We can only navigate named fields in ROWs in this way, so we can't expand out + MAP or ARRAY types, nor fields in ROWs which have no name (in fact the trino + library doesn't correctly parse unnamed fields in ROWs). We won't be able to + expand ROWs which are nested underneath any of those types, either. + + Expanded columns are named foo.bar.baz and we provide a query_as property to + instruct the base engine spec how to correctly query them: instead of quoting + the whole string they have to be quoted like "foo"."bar"."baz" and we then + alias them to the full dotted string for ease of reference. + """ + cols = [col] + col_type = col.get("type") + + if not isinstance(col_type, datatype.ROW): + return cols + + for inner_name, inner_type in col_type.attr_types: + outer_name = col["name"] + name = ".".join([outer_name, inner_name]) + query_name = ".".join([f'"{piece}"' for piece in name.split(".")]) + column_spec = cls.get_column_spec(str(inner_type)) + is_dttm = column_spec.is_dttm if column_spec else False + + inner_col = ResultSetColumnType( + name=name, + column_name=name, + type=inner_type, + is_dttm=is_dttm, + query_as=f'{query_name} AS "{name}"', + ) + cols.extend(cls._expand_columns(inner_col)) + + return cols + + @classmethod + def get_columns( + cls, + inspector: Inspector, + table_name: str, + schema: str | None, + options: dict[str, Any] | None = None, + ) -> list[ResultSetColumnType]: + """ + If the "expand_rows" feature is enabled on the database via + "schema_options", expand the schema definition out to show all + subfields of nested ROWs as their appropriate dotted paths. + """ + base_cols = super().get_columns(inspector, table_name, schema, options) + if not (options or {}).get("expand_rows"): + return base_cols + + return [col for base_col in base_cols for col in cls._expand_columns(base_col)] diff --git a/superset/models/core.py b/superset/models/core.py index 6fa394de06cad..d2b38ea806513 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -237,6 +237,11 @@ def disable_data_preview(self) -> bool: # this will prevent any 'trash value' strings from going through return self.get_extra().get("disable_data_preview", False) is True + @property + def schema_options(self) -> dict[str, Any]: + """Additional schema display config for engines with complex schemas""" + return self.get_extra().get("schema_options", {}) + @property def data(self) -> dict[str, Any]: return { @@ -248,6 +253,7 @@ def data(self) -> dict[str, Any]: "allows_cost_estimate": self.allows_cost_estimate, "allows_virtual_table_explore": self.allows_virtual_table_explore, "explore_database_id": self.explore_database_id, + "schema_options": self.schema_options, "parameters": self.parameters, "disable_data_preview": self.disable_data_preview, "parameters_schema": self.parameters_schema, @@ -838,7 +844,9 @@ def get_columns( self, table_name: str, schema: str | None = None ) -> list[ResultSetColumnType]: with self.get_inspector_with_context() as inspector: - return self.db_engine_spec.get_columns(inspector, table_name, schema) + return self.db_engine_spec.get_columns( + inspector, table_name, schema, self.schema_options + ) def get_metrics( self, diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 953683b5dcd01..c71dcea3f1a2d 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -84,6 +84,8 @@ class ResultSetColumnType(TypedDict): scale: NotRequired[Any] max_length: NotRequired[Any] + query_as: NotRequired[Any] + CacheConfig = dict[str, Any] DbapiDescriptionRow = tuple[ diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 1b50a683a0841..15e55fc5af62f 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access +import copy import json from datetime import datetime from typing import Any, Optional @@ -24,9 +25,11 @@ import pytest from pytest_mock import MockerFixture from sqlalchemy import types +from trino.sqlalchemy import datatype import superset.config from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT +from superset.superset_typing import ResultSetColumnType, SQLAColumnType from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, @@ -35,6 +38,24 @@ from tests.unit_tests.fixtures.common import dttm +def _assert_columns_equal(actual_cols, expected_cols) -> None: + """ + Assert equality of the given cols, bearing in mind sqlalchemy type + instances can't be compared for equality, so will have to be converted to + strings first. + """ + actual = copy.deepcopy(actual_cols) + expected = copy.deepcopy(expected_cols) + + for col in actual: + col["type"] = str(col["type"]) + + for col in expected: + col["type"] = str(col["type"]) + + assert actual == expected + + @pytest.mark.parametrize( "extra,expected", [ @@ -395,3 +416,104 @@ def _mock_execute(*args, **kwargs): mock_query.set_extra_json_key.assert_called_once_with( key=QUERY_CANCEL_KEY, value=query_id ) + + +def test_get_columns(mocker: MockerFixture): + """Test that ROW columns are not expanded without expand_rows""" + from superset.db_engine_specs.trino import TrinoEngineSpec + + field1_type = datatype.parse_sqltype("row(a varchar, b date)") + field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))") + field3_type = datatype.parse_sqltype("int") + + sqla_columns = [ + SQLAColumnType(name="field1", type=field1_type, is_dttm=False), + SQLAColumnType(name="field2", type=field2_type, is_dttm=False), + SQLAColumnType(name="field3", type=field3_type, is_dttm=False), + ] + mock_inspector = mocker.MagicMock() + mock_inspector.get_columns.return_value = sqla_columns + + actual = TrinoEngineSpec.get_columns(mock_inspector, "table", "schema") + expected = [ + ResultSetColumnType( + name="field1", column_name="field1", type=field1_type, is_dttm=False + ), + ResultSetColumnType( + name="field2", column_name="field2", type=field2_type, is_dttm=False + ), + ResultSetColumnType( + name="field3", column_name="field3", type=field3_type, is_dttm=False + ), + ] + + _assert_columns_equal(actual, expected) + + +def test_get_columns_expand_rows(mocker: MockerFixture): + """Test that ROW columns are correctly expanded with expand_rows""" + from superset.db_engine_specs.trino import TrinoEngineSpec + + field1_type = datatype.parse_sqltype("row(a varchar, b date)") + field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))") + field3_type = datatype.parse_sqltype("int") + + sqla_columns = [ + SQLAColumnType(name="field1", type=field1_type, is_dttm=False), + SQLAColumnType(name="field2", type=field2_type, is_dttm=False), + SQLAColumnType(name="field3", type=field3_type, is_dttm=False), + ] + mock_inspector = mocker.MagicMock() + mock_inspector.get_columns.return_value = sqla_columns + + actual = TrinoEngineSpec.get_columns( + mock_inspector, "table", "schema", {"expand_rows": True} + ) + expected = [ + ResultSetColumnType( + name="field1", column_name="field1", type=field1_type, is_dttm=False + ), + ResultSetColumnType( + name="field1.a", + column_name="field1.a", + type=types.VARCHAR(), + is_dttm=False, + query_as='"field1"."a" AS "field1.a"', + ), + ResultSetColumnType( + name="field1.b", + column_name="field1.b", + type=types.DATE(), + is_dttm=True, + query_as='"field1"."b" AS "field1.b"', + ), + ResultSetColumnType( + name="field2", column_name="field2", type=field2_type, is_dttm=False + ), + ResultSetColumnType( + name="field2.r1", + column_name="field2.r1", + type=datatype.parse_sqltype("row(a varchar, b varchar)"), + is_dttm=False, + query_as='"field2"."r1" AS "field2.r1"', + ), + ResultSetColumnType( + name="field2.r1.a", + column_name="field2.r1.a", + type=types.VARCHAR(), + is_dttm=False, + query_as='"field2"."r1"."a" AS "field2.r1.a"', + ), + ResultSetColumnType( + name="field2.r1.b", + column_name="field2.r1.b", + type=types.VARCHAR(), + is_dttm=False, + query_as='"field2"."r1"."b" AS "field2.r1.b"', + ), + ResultSetColumnType( + name="field3", column_name="field3", type=field3_type, is_dttm=False + ), + ] + + _assert_columns_equal(actual, expected)