From 601e55656c437091007a8a51dabed6ba440d792b Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Fri, 6 Sep 2024 11:15:40 -0700 Subject: [PATCH] feat(jinja): add advanced temporal filter functionality (#30142) --- docs/docs/configuration/sql-templating.mdx | 73 ++++++++- superset/jinja_context.py | 103 ++++++++++++- tests/unit_tests/jinja_context_test.py | 164 +++++++++++++++++++++ 3 files changed, 335 insertions(+), 5 deletions(-) diff --git a/docs/docs/configuration/sql-templating.mdx b/docs/docs/configuration/sql-templating.mdx index 4fbddfe530660..c171f678db7f5 100644 --- a/docs/docs/configuration/sql-templating.mdx +++ b/docs/docs/configuration/sql-templating.mdx @@ -17,8 +17,8 @@ made available in the Jinja context: - `columns`: columns which to group by in the query - `filter`: filters applied in the query -- `from_dttm`: start `datetime` value from the selected time range (`None` if undefined) -- `to_dttm`: end `datetime` value from the selected time range (`None` if undefined) +- `from_dttm`: start `datetime` value from the selected time range (`None` if undefined) (deprecated beginning in version 5.0, use `get_time_filter` instead) +- `to_dttm`: end `datetime` value from the selected time range (`None` if undefined) (deprecated beginning in version 5.0, use `get_time_filter` instead) - `groupby`: columns which to group by in the query (deprecated) - `metrics`: aggregate expressions in the query - `row_limit`: row limit of the query @@ -346,6 +346,75 @@ Here's a concrete example: order by lineage, level ``` +**Time Filter** + +The `{{ get_time_filter() }}` macro returns the time filter applied to a specific column. This is useful if you want +to handle time filters inside the virtual dataset, as by default the time filter is placed on the outer query. 
This can +considerably improve performance, as many databases and query engines are able to optimize the query better +if the temporal filter is placed on the inner query, as opposed to the outer query. + +The macro takes the following parameters: +- `column`: Name of the temporal column. Leave undefined to reference the time range from a Dashboard Native Time Range + filter (when present). +- `default`: The default value to fall back to if the time filter is not present, or has the value `No filter`. +- `target_type`: The target temporal type as recognized by the target database (e.g. `TIMESTAMP`, `DATE` or + `DATETIME`). If `column` is defined, the format will default to the type of the column. This is used to produce + the format of the `from_expr` and `to_expr` properties of the returned `TimeFilter` object. +- `remove_filter`: When set to true, mark the filter as processed, removing it from the outer query. Useful when a + filter should only apply to the inner query. + +The return type has the following properties: +- `from_expr`: the start of the time filter (if any) +- `to_expr`: the end of the time filter (if any) +- `time_range`: The applied time range + +Here's a concrete example using the `logs` table from the Superset metastore: + +``` +{% set time_filter = get_time_filter("dttm", remove_filter=True) %} +{% set from_expr = time_filter.from_expr %} +{% set to_expr = time_filter.to_expr %} +{% set time_range = time_filter.time_range %} +SELECT + *, + '{{ time_range }}' as time_range +FROM logs +{% if from_expr or to_expr %}WHERE 1 = 1 +{% if from_expr %}AND dttm >= {{ from_expr }}{% endif %} +{% if to_expr %}AND dttm < {{ to_expr }}{% endif %} +{% endif %} +``` + +Assuming we are creating a table chart with a simple `COUNT(*)` as the metric with a time filter `Last week` on the +`dttm` column, this would render the following query on Postgres (note the formatting of the temporal filters, and +the absence of time filters on the outer query): + +``` +SELECT 
COUNT(*) AS count +FROM + (SELECT *, + 'Last week' AS time_range + FROM public.logs + WHERE 1 = 1 + AND dttm >= TO_TIMESTAMP('2024-08-27 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US') + AND dttm < TO_TIMESTAMP('2024-09-03 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')) AS virtual_table +ORDER BY count DESC +LIMIT 1000; +``` + +When using the `default` parameter, the templated query can be simplified, as the endpoints will always be defined +(to use a fixed time range, you can also use something like `default="2024-08-27 : 2024-09-03"`) +``` +{% set time_filter = get_time_filter("dttm", default="Last week", remove_filter=True) %} +SELECT + *, + '{{ time_filter.time_range }}' as time_range +FROM logs +WHERE + dttm >= {{ time_filter.from_expr }} + AND dttm < {{ time_filter.to_expr }} +``` + **Datasets** It's possible to query physical and virtual datasets using the `dataset` macro. This is useful if you've defined computed columns and metrics on your datasets, and want to reuse the definition in adhoc SQL Lab queries. diff --git a/superset/jinja_context.py b/superset/jinja_context.py index a2c87db765071..625e59fe6ae38 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -16,7 +16,10 @@ # under the License. 
"""Defines the templating context for SQL Lab""" +from __future__ import annotations + import re +from dataclasses import dataclass from datetime import datetime from functools import lru_cache, partial from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Union @@ -31,13 +34,16 @@ from sqlalchemy.types import String from superset.commands.dataset.exceptions import DatasetNotFoundError -from superset.constants import LRU_CACHE_MAX_SIZE +from superset.common.utils.time_range_utils import get_since_until_from_time_range +from superset.constants import LRU_CACHE_MAX_SIZE, NO_TIME_RANGE from superset.exceptions import SupersetTemplateException from superset.extensions import feature_flag_manager from superset.sql_parse import Table from superset.utils import json from superset.utils.core import ( + AdhocFilterClause, convert_legacy_filters_into_adhoc, + FilterOperator, get_user_email, get_user_id, get_username, @@ -62,6 +68,7 @@ "dict", "tuple", "set", + "TimeFilter", ) COLLECTION_TYPES = ("list", "dict", "tuple", "set") @@ -77,6 +84,17 @@ class Filter(TypedDict): val: Union[None, Any, list[Any]] +@dataclass +class TimeFilter: + """ + Container for temporal filter. 
+ """ + + from_expr: str | None + to_expr: str | None + time_range: str | None + + class ExtraCache: """ Dummy class that exposes a method used to store additional values used in @@ -95,17 +113,21 @@ class ExtraCache: r").*\}\}" ) - def __init__( + def __init__( # pylint: disable=too-many-arguments self, extra_cache_keys: Optional[list[Any]] = None, applied_filters: Optional[list[str]] = None, removed_filters: Optional[list[str]] = None, + database: Optional[Database] = None, dialect: Optional[Dialect] = None, + table: Optional[SqlaTable] = None, ): self.extra_cache_keys = extra_cache_keys self.applied_filters = applied_filters if applied_filters is not None else [] self.removed_filters = removed_filters if removed_filters is not None else [] + self.database = database self.dialect = dialect + self.table = table def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]: """ @@ -319,7 +341,6 @@ def get_filters(self, column: str, remove_filter: bool = False) -> list[Filter]: :return: returns a list of filters """ # pylint: disable=import-outside-toplevel - from superset.utils.core import FilterOperator from superset.views.utils import get_form_data form_data, _ = get_form_data() @@ -354,6 +375,77 @@ def get_filters(self, column: str, remove_filter: bool = False) -> list[Filter]: return filters + def get_time_filter( + self, + column: str | None = None, + default: str | None = None, + target_type: str | None = None, + remove_filter: bool = False, + ) -> TimeFilter: + """Get the time filter with appropriate formatting, + either for a specific column, or whichever time range is being emitted + from a dashboard. + + :param column: Name of the temporal column. Leave undefined to reference the + time range from a Dashboard Native Time Range filter (when present). 
+ :param default: The default value to fall back to if the time filter is + not present, or has the value `No filter` + :param target_type: The target temporal type as recognized by the target + database (e.g. `TIMESTAMP`, `DATE` or `DATETIME`). If `column` is defined, + the format will default to the type of the column. This is used to produce + the format of the `from_expr` and `to_expr` properties of the returned + `TimeFilter` object. + :param remove_filter: When set to true, mark the filter as processed, + removing it from the outer query. Useful when a filter should + only apply to the inner query. + :return: The corresponding time filter. + """ + # pylint: disable=import-outside-toplevel + from superset.views.utils import get_form_data + + form_data, _ = get_form_data() + convert_legacy_filters_into_adhoc(form_data) + merge_extra_filters(form_data) + time_range = form_data.get("time_range") + if column: + flt: AdhocFilterClause | None = next( + ( + flt + for flt in form_data.get("adhoc_filters", []) + if flt["operator"] == FilterOperator.TEMPORAL_RANGE + and flt["subject"] == column + ), + None, + ) + if flt: + if remove_filter: + if column not in self.removed_filters: + self.removed_filters.append(column) + if column not in self.applied_filters: + self.applied_filters.append(column) + + time_range = cast(str, flt["comparator"]) + if not target_type and self.table: + target_type = self.table.columns_types.get(column) + + time_range = time_range or NO_TIME_RANGE + if time_range == NO_TIME_RANGE and default: + time_range = default + from_expr, to_expr = get_since_until_from_time_range(time_range) + + def _format_dttm(dttm: datetime | None) -> str | None: + return ( + self.database.db_engine_spec.convert_dttm(target_type or "", dttm) + if self.database and dttm + else None + ) + + return TimeFilter( + from_expr=_format_dttm(from_expr), + to_expr=_format_dttm(to_expr), + time_range=time_range, + ) + def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: 
Any) -> Any: return_value = func(*args, **kwargs) @@ -477,6 +569,7 @@ def __init__( self._schema = query.schema elif table: self._schema = table.schema + self._table = table self._extra_cache_keys = extra_cache_keys self._applied_filters = applied_filters self._removed_filters = removed_filters @@ -525,7 +618,9 @@ def set_context(self, **kwargs: Any) -> None: extra_cache_keys=self._extra_cache_keys, applied_filters=self._applied_filters, removed_filters=self._removed_filters, + database=self._database, dialect=self._database.get_dialect(), + table=self._table, ) from_dttm = ( @@ -544,6 +639,7 @@ def set_context(self, **kwargs: Any) -> None: from_dttm=from_dttm, to_dttm=to_dttm, ) + self._context.update( { "url_param": partial(safe_proxy, extra_cache.url_param), @@ -557,6 +653,7 @@ def set_context(self, **kwargs: Any) -> None: "get_filters": partial(safe_proxy, extra_cache.get_filters), "dataset": partial(safe_proxy, dataset_macro_with_context), "metric": partial(safe_proxy, metric_macro), + "get_time_filter": partial(safe_proxy, extra_cache.get_time_filter), } ) diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index a1ceaa08b2c0e..e13c4dcc339bf 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=invalid-name, unused-argument +from __future__ import annotations from typing import Any import pytest +from freezegun import freeze_time from pytest_mock import MockerFixture from sqlalchemy.dialects import mysql from sqlalchemy.dialects.postgresql import dialect @@ -32,6 +34,7 @@ ExtraCache, metric_macro, safe_proxy, + TimeFilter, WhereInMacro, ) from superset.models.core import Database @@ -836,3 +839,164 @@ def test_metric_macro_no_dataset_id_with_context_chart_no_datasource_id( ) mock_get_form_data.assert_called_once() DatasetDAO.find_by_id.assert_not_called() + + +@pytest.mark.parametrize( + "description,args,kwargs,sqlalchemy_uri,queries,time_filter,removed_filters,applied_filters", + [ + ( + "Missing time_range and filter will return a No filter result", + [], + {"target_type": "TIMESTAMP"}, + "postgresql://mydb", + [{}], + TimeFilter( + from_expr=None, + to_expr=None, + time_range="No filter", + ), + [], + [], + ), + ( + "Missing time range and filter with default value will return a result with the defaults", + [], + {"default": "Last week", "target_type": "TIMESTAMP"}, + "postgresql://mydb", + [{}], + TimeFilter( + from_expr="TO_TIMESTAMP('2024-08-27 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')", + to_expr="TO_TIMESTAMP('2024-09-03 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')", + time_range="Last week", + ), + [], + [], + ), + ( + "Time range is extracted with the expected format, and default is ignored", + [], + {"default": "Last month", "target_type": "TIMESTAMP"}, + "postgresql://mydb", + [{"time_range": "Last week"}], + TimeFilter( + from_expr="TO_TIMESTAMP('2024-08-27 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')", + to_expr="TO_TIMESTAMP('2024-09-03 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')", + time_range="Last week", + ), + [], + [], + ), + ( + "Filter is extracted with the native format of the column (TIMESTAMP)", + ["dttm"], + {}, + "postgresql://mydb", + [ + { + "filters": [ + { + "col": "dttm", + "op": "TEMPORAL_RANGE", 
+ "val": "Last week", + }, + ], + } + ], + TimeFilter( + from_expr="TO_TIMESTAMP('2024-08-27 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')", + to_expr="TO_TIMESTAMP('2024-09-03 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')", + time_range="Last week", + ), + [], + ["dttm"], + ), + ( + "Filter is extracted with the native format of the column (DATE)", + ["dt"], + {"remove_filter": True}, + "postgresql://mydb", + [ + { + "filters": [ + { + "col": "dt", + "op": "TEMPORAL_RANGE", + "val": "Last week", + }, + ], + } + ], + TimeFilter( + from_expr="TO_DATE('2024-08-27', 'YYYY-MM-DD')", + to_expr="TO_DATE('2024-09-03', 'YYYY-MM-DD')", + time_range="Last week", + ), + ["dt"], + ["dt"], + ), + ( + "Filter is extracted with the overridden format (TIMESTAMP to DATE)", + ["dttm"], + {"target_type": "DATE", "remove_filter": True}, + "trino://mydb", + [ + { + "filters": [ + { + "col": "dttm", + "op": "TEMPORAL_RANGE", + "val": "Last month", + }, + ], + } + ], + TimeFilter( + from_expr="DATE '2024-08-03'", + to_expr="DATE '2024-09-03'", + time_range="Last month", + ), + ["dttm"], + ["dttm"], + ), + ], +) +def test_get_time_filter( + description: str, + args: list[Any], + kwargs: dict[str, Any], + sqlalchemy_uri: str, + queries: list[Any] | None, + time_filter: TimeFilter, + removed_filters: list[str], + applied_filters: list[str], +) -> None: + """ + Test the ``get_time_filter`` macro. 
+ """ + columns = [ + TableColumn(column_name="dt", is_dttm=1, type="DATE"), + TableColumn(column_name="dttm", is_dttm=1, type="TIMESTAMP"), + ] + + database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri) + table = SqlaTable( + table_name="my_dataset", + columns=columns, + main_dttm_col="dt", + database=database, + ) + + with ( + freeze_time("2024-09-03"), + app.test_request_context( + json={"queries": queries}, + ), + ): + cache = ExtraCache( + database=database, + table=table, + ) + + assert cache.get_time_filter(*args, **kwargs) == time_filter, description + assert cache.removed_filters == removed_filters + assert cache.applied_filters == applied_filters