diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 75354ad355d3d..fb7409adba589 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -116,7 +116,6 @@ ) from superset.utils import core as utils, json from superset.utils.backports import StrEnum -from superset.utils.core import GenericDataType, is_adhoc_column, MediumText config = app.config metadata = Model.metadata # pylint: disable=no-member @@ -477,7 +476,7 @@ def data_for_slices( # pylint: disable=too-many-locals ] filtered_columns: list[Column] = [] - column_types: set[GenericDataType] = set() + column_types: set[utils.GenericDataType] = set() for column_ in data["columns"]: generic_type = column_.get("type_generic") if generic_type is not None: @@ -511,7 +510,7 @@ def data_for_slices( # pylint: disable=too-many-locals def filter_values_handler( # pylint: disable=too-many-arguments values: FilterValues | None, operator: str, - target_generic_type: GenericDataType, + target_generic_type: utils.GenericDataType, target_native_type: str | None = None, is_list_target: bool = False, db_engine_spec: builtins.type[BaseEngineSpec] | None = None, @@ -829,10 +828,10 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod advanced_data_type = Column(String(255)) groupby = Column(Boolean, default=True) filterable = Column(Boolean, default=True) - description = Column(MediumText()) + description = Column(utils.MediumText()) table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE")) is_dttm = Column(Boolean, default=False) - expression = Column(MediumText()) + expression = Column(utils.MediumText()) python_date_format = Column(String(255)) extra = Column(Text) @@ -892,21 +891,21 @@ def is_boolean(self) -> bool: """ Check if the column has a boolean datatype. """ - return self.type_generic == GenericDataType.BOOLEAN + return self.type_generic == utils.GenericDataType.BOOLEAN @property def is_numeric(self) -> bool: """ Check if the column has a numeric datatype. """ - return self.type_generic == GenericDataType.NUMERIC + return self.type_generic == utils.GenericDataType.NUMERIC @property def is_string(self) -> bool: """ Check if the column has a string datatype. """ - return self.type_generic == GenericDataType.STRING + return self.type_generic == utils.GenericDataType.STRING @property def is_temporal(self) -> bool: @@ -918,7 +917,7 @@ def is_temporal(self) -> bool: """ if self.is_dttm is not None: return self.is_dttm - return self.type_generic == GenericDataType.TEMPORAL + return self.type_generic == utils.GenericDataType.TEMPORAL @property def database(self) -> Database: @@ -935,7 +934,7 @@ def db_extra(self) -> dict[str, Any]: @property def type_generic(self) -> utils.GenericDataType | None: if self.is_dttm: - return GenericDataType.TEMPORAL + return utils.GenericDataType.TEMPORAL return ( column_spec.generic_type @@ -1038,12 +1037,12 @@ class SqlMetric(AuditMixinNullable, ImportExportMixin, CertificationMixin, Model metric_name = Column(String(255), nullable=False) verbose_name = Column(String(1024)) metric_type = Column(String(32)) - description = Column(MediumText()) + description = Column(utils.MediumText()) d3format = Column(String(128)) currency = Column(String(128)) warning_text = Column(Text) table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE")) - expression = Column(MediumText(), nullable=False) + expression = Column(utils.MediumText(), nullable=False) extra = Column(Text) table: Mapped[SqlaTable] = relationship( @@ -1185,7 +1184,7 @@ class SqlaTable( ) schema = Column(String(255)) catalog = Column(String(256), nullable=True, default=None) - sql = Column(MediumText()) + sql = Column(utils.MediumText()) is_sqllab_view = Column(Boolean, default=False) template_params = Column(Text) extra = Column(Text) @@ -1980,10 +1979,26 @@ def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool: templatable_statements.append(extras["where"]) if "having" in extras: templatable_statements.append(extras["having"]) - if "columns" in query_obj: - templatable_statements += [ - c["sqlExpression"] for c in query_obj["columns"] if is_adhoc_column(c) - ] + if columns := query_obj.get("columns"): + calculated_columns: dict[str, Any] = { + c.column_name: c.expression for c in self.columns if c.expression + } + for column_ in columns: + if utils.is_adhoc_column(column_): + templatable_statements.append(column_["sqlExpression"]) + elif isinstance(column_, str) and column_ in calculated_columns: + templatable_statements.append(calculated_columns[column_]) + if metrics := query_obj.get("metrics"): + metrics_by_name: dict[str, Any] = { + m.metric_name: m.expression for m in self.metrics + } + for metric in metrics: + if utils.is_adhoc_metric(metric) and ( + sql := metric.get("sqlExpression") + ): + templatable_statements.append(sql) + elif isinstance(metric, str) and metric in metrics_by_name: + templatable_statements.append(metrics_by_name[metric]) if self.is_rls_supported: templatable_statements += [ f.clause for f in security_manager.get_rls_filters(self) @@ -2125,4 +2140,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): secondary=RLSFilterTables, backref="row_level_security_filters", ) - clause = Column(MediumText(), nullable=False) + clause = Column(utils.MediumText(), nullable=False) diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 79d4bf00ed028..d4ca3bc1c1a47 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -15,11 +15,13 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file +from __future__ import annotations + import re from datetime import datetime -from typing import Any, NamedTuple, Optional, Union +from typing import Any, Literal, NamedTuple, Optional, Union from re import Pattern -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest import numpy as np @@ -913,54 +915,99 @@ def test_extra_cache_keys_in_sql_expression( @pytest.mark.usefixtures("app_context") @pytest.mark.parametrize( - "sql_expression,expected_cache_keys,has_extra_cache_keys", + "sql_expression,expected_cache_keys,has_extra_cache_keys,item_type", [ - ("'{{ current_username() }}'", ["abc"], True), - ("(user != 'abc')", [], False), + ("'{{ current_username() }}'", ["abc"], True, "columns"), + ("(user != 'abc')", [], False, "columns"), + ("{{ current_user_id() }}", [1], True, "metrics"), + ("COUNT(*)", [], False, "metrics"), ], ) @patch("superset.jinja_context.get_user_id", return_value=1) @patch("superset.jinja_context.get_username", return_value="abc") -@patch("superset.jinja_context.get_user_email", return_value="abc@test.com") -def test_extra_cache_keys_in_columns( - mock_user_email, - mock_username, - mock_user_id, - sql_expression, - expected_cache_keys, - has_extra_cache_keys, +def test_extra_cache_keys_in_adhoc_metrics_and_columns( + mock_username: Mock, + mock_user_id: Mock, + sql_expression: str, + expected_cache_keys: list[str | None], + has_extra_cache_keys: bool, + item_type: Literal["columns", "metrics"], ): table = SqlaTable( table_name="test_has_no_extra_cache_keys_table", sql="SELECT 'abc' as user", database=get_example_database(), ) - base_query_obj = { + base_query_obj: dict[str, Any] = { "granularity": None, "from_dttm": None, "to_dttm": None, "groupby": [], "metrics": [], + "columns": [], "is_timeseries": False, "filter": [], } - query_obj = dict( - **base_query_obj, - columns=[ + items: dict[str, Any] = { + item_type: [ { "label": None, "expressionType": "SQL", "sqlExpression": sql_expression, } ], - ) + } + + query_obj = {**base_query_obj, **items} extra_cache_keys = table.get_extra_cache_keys(query_obj) assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys assert extra_cache_keys == expected_cache_keys +@pytest.mark.usefixtures("app_context") +@patch("superset.jinja_context.get_user_id", return_value=1) +@patch("superset.jinja_context.get_username", return_value="abc") +def test_extra_cache_keys_in_dataset_metrics_and_columns( + mock_username: Mock, + mock_user_id: Mock, +): + table = SqlaTable( + table_name="test_has_no_extra_cache_keys_table", + sql="SELECT 'abc' as user", + database=get_example_database(), + columns=[ + TableColumn(column_name="user", type="VARCHAR(255)"), + TableColumn( + column_name="username", + type="VARCHAR(255)", + expression="{{ current_username() }}", + ), + ], + metrics=[ + SqlMetric( + metric_name="variable_profit", + expression="SUM(price) * {{ url_param('multiplier') }}", + ), + ], + ) + query_obj: dict[str, Any] = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": [], + "columns": ["username"], + "metrics": ["variable_profit"], + "is_timeseries": False, + "filter": [], + } + + extra_cache_keys = table.get_extra_cache_keys(query_obj) + assert table.has_extra_cache_key_calls(query_obj) is True + assert set(extra_cache_keys) == {"abc", None} + + @pytest.mark.usefixtures("app_context") @pytest.mark.parametrize( "row,dimension,result",