diff --git a/superset/datasets/api.py b/superset/datasets/api.py
index f8f6bdc0b9604..762727aafcdff 100644
--- a/superset/datasets/api.py
+++ b/superset/datasets/api.py
@@ -15,16 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=too-many-lines
+from __future__ import annotations
+
 import logging
 from datetime import datetime
 from io import BytesIO
-from typing import Any
+from typing import Any, Callable
 from zipfile import is_zipfile, ZipFile
 
 from flask import request, Response, send_file
 from flask_appbuilder.api import expose, protect, rison, safe
+from flask_appbuilder.api.schemas import get_item_schema
+from flask_appbuilder.const import (
+    API_RESULT_RES_KEY,
+    API_SELECT_COLUMNS_RIS_KEY,
+)
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 from flask_babel import ngettext
+from jinja2.exceptions import TemplateSyntaxError
 from marshmallow import ValidationError
 
 from superset import event_logger
@@ -65,6 +73,8 @@
     GetOrCreateDatasetSchema,
     openapi_spec_methods_override,
 )
+from superset.exceptions import SupersetTemplateException
+from superset.jinja_context import BaseTemplateProcessor, get_template_processor
 from superset.utils import json
 from superset.utils.core import parse_boolean_string
 from superset.views.base import DatasourceFilter
@@ -75,6 +85,7 @@
     requires_json,
     statsd_metrics,
 )
+from superset.views.error_handling import handle_api_exception
 from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners
 
 logger = logging.getLogger(__name__)
@@ -1056,3 +1067,150 @@ def warm_up_cache(self) -> Response:
             return self.response(200, result=result)
         except CommandException as ex:
             return self.response(ex.status, message=ex.message)
+
+    @expose("/<int:pk>", methods=("GET",))
+    @protect()
+    @safe
+    @rison(get_item_schema)
+    @statsd_metrics
+    @handle_api_exception
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
+        log_to_statsd=False,
+    )
+    def get(self, pk: int, **kwargs: Any) -> Response:
+        """Get a dataset.
+        ---
+        get:
+          summary: Get a dataset
+          description: Get a dataset by ID
+          parameters:
+          - in: path
+            schema:
+              type: integer
+            description: The dataset ID
+            name: pk
+          - in: query
+            name: q
+            content:
+              application/json:
+                schema:
+                  $ref: '#/components/schemas/get_item_schema'
+          - in: query
+            name: include_rendered_sql
+            description: >-
+              Should Jinja macros from sql, metrics and columns be rendered
+              and included in the response
+            schema:
+              type: boolean
+          responses:
+            200:
+              description: Dataset object has been returned.
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      id:
+                        description: The item id
+                        type: string
+                      result:
+                        $ref: '#/components/schemas/{{self.__class__.__name__}}.get'
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            404:
+              $ref: '#/components/responses/404'
+            422:
+              $ref: '#/components/responses/422'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        item: SqlaTable | None = self.datamodel.get(
+            pk,
+            self._base_filters,
+            self.show_select_columns,
+            self.show_outer_default_load,
+        )
+        if not item:
+            return self.response_404()
+
+        response: dict[str, Any] = {}
+        args = kwargs.get("rison", {})
+        select_cols = args.get(API_SELECT_COLUMNS_RIS_KEY, [])
+        # Keep only the requested columns that are listed in ``show_columns``.
+        pruned_select_cols = [col for col in select_cols if col in self.show_columns]
+        self.set_response_key_mappings(
+            response,
+            self.get,
+            args,
+            **{API_SELECT_COLUMNS_RIS_KEY: pruned_select_cols},
+        )
+        if pruned_select_cols:
+            show_model_schema = self.model2schemaconverter.convert(pruned_select_cols)
+        else:
+            show_model_schema = self.show_model_schema
+
+        response["id"] = pk
+        response[API_RESULT_RES_KEY] = show_model_schema.dump(item, many=False)
+
+        # Rendering is opt-in via ``include_rendered_sql``; a failure yields a 400.
+        if parse_boolean_string(request.args.get("include_rendered_sql")):
+            try:
+                processor = get_template_processor(database=item.database)
+                response[API_RESULT_RES_KEY] = self.render_dataset_fields(
+                    response[API_RESULT_RES_KEY], processor
+                )
+            except SupersetTemplateException as ex:
+                return self.response_400(message=str(ex))
+        return self.response(200, **response)
+
+    @staticmethod
+    def render_dataset_fields(
+        data: dict[str, Any], processor: BaseTemplateProcessor
+    ) -> dict[str, Any]:
+        """
+        Renders Jinja macros in the ``sql``, ``metrics`` and ``columns`` fields.
+
+        The rendered query is stored under ``rendered_sql``; every metric and
+        column that has an ``expression`` gains a ``rendered_expression`` key.
+        ``data`` itself is updated and returned; empty or missing fields are skipped.
+
+        :param data: Dataset info to be rendered
+        :param processor: A ``TemplateProcessor`` instance
+        :return: Rendered dataset data
+        :raises SupersetTemplateException: If a template has invalid Jinja syntax
+        """
+
+        def render_item_list(item_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
+            return [
+                {
+                    **item,
+                    "rendered_expression": processor.process_template(
+                        item["expression"]
+                    ),
+                }
+                if item.get("expression")
+                else item
+                for item in item_list
+            ]
+
+        # (label used in the error message, source key, target key, renderer)
+        items: list[tuple[str, str, str, Callable[[Any], Any]]] = [
+            ("query", "sql", "rendered_sql", processor.process_template),
+            ("metric", "metrics", "metrics", render_item_list),
+            ("calculated column", "columns", "columns", render_item_list),
+        ]
+        for item_type, key, new_key, func in items:
+            if not data.get(key):
+                continue
+
+            try:
+                data[new_key] = func(data[key])
+            except TemplateSyntaxError as ex:
+                raise SupersetTemplateException(
+                    f"Unable to render expression from dataset {item_type}.",
+                ) from ex
+
+        return data
diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py
index 5b899d8402f23..4b7e92d7ff5bb 100644
--- a/superset/datasets/schemas.py
+++ b/superset/datasets/schemas.py
@@ -29,7 +29,6 @@
 get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}
 
 openapi_spec_methods_override = {
-    "get": {"get": {"summary": "Get a dataset detail information"}},
     "get_list": {
         "get": {
             "summary": "Get a list of datasets",
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 49110277bf328..b04d4cec73692 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -410,6 +410,159 @@ def test_get_dataset_item(self):
         assert len(response["result"]["columns"]) == 3
         assert len(response["result"]["metrics"]) == 2
 
+    def test_get_dataset_render_jinja(self):
+        """
+        Dataset API: Test get dataset with the render parameter.
+        """
+        database = get_example_database()
+        dataset = SqlaTable(
+            table_name="test_sql_table_with_jinja",
+            database=database,
+            schema=get_example_default_schema(),
+            main_dttm_col="default_dttm",
+            columns=[
+                TableColumn(
+                    column_name="my_user_id",
+                    type="INTEGER",
+                    is_dttm=False,
+                ),
+                TableColumn(
+                    column_name="calculated_test",
+                    type="VARCHAR(255)",
+                    is_dttm=False,
+                    expression="'{{ current_username() }}'",
+                ),
+            ],
+            metrics=[
+                SqlMetric(
+                    metric_name="param_test",
+                    expression="{{ url_param('multiplier') }} * 1.4",
+                )
+            ],
+            sql="SELECT {{ current_user_id() }} as my_user_id",
+        )
+        db.session.add(dataset)
+        db.session.commit()
+
+        try:
+            self.login(ADMIN_USERNAME)
+            admin = self.get_user(ADMIN_USERNAME)
+            uri = (
+                f"api/v1/dataset/{dataset.id}?"
+                "q=(columns:!(id,sql,columns.column_name,columns.expression,"
+                "metrics.metric_name,metrics.expression))"
+                "&include_rendered_sql=true&multiplier=4"
+            )
+            rv = self.get_assert_metric(uri, "get")
+            assert rv.status_code == 200
+            response = json.loads(rv.data.decode("utf-8"))
+
+            assert response["result"] == {
+                "id": dataset.id,
+                "sql": "SELECT {{ current_user_id() }} as my_user_id",
+                "rendered_sql": f"SELECT {admin.id} as my_user_id",
+                "columns": [
+                    {
+                        "column_name": "my_user_id",
+                        "expression": None,
+                    },
+                    {
+                        "column_name": "calculated_test",
+                        "expression": "'{{ current_username() }}'",
+                        "rendered_expression": f"'{admin.username}'",
+                    },
+                ],
+                "metrics": [
+                    {
+                        "metric_name": "param_test",
+                        "expression": "{{ url_param('multiplier') }} * 1.4",
+                        "rendered_expression": "4 * 1.4",
+                    },
+                ],
+            }
+        finally:
+            # Always remove the fixture so a failed assertion cannot leak it.
+            db.session.delete(dataset)
+            db.session.commit()
+
+    def test_get_dataset_render_jinja_exceptions(self):
+        """
+        Dataset API: Test get dataset with the render parameter
+        when rendering raises an exception.
+        """
+        database = get_example_database()
+        dataset = SqlaTable(
+            table_name="test_sql_table_with_incorrect_jinja",
+            database=database,
+            schema=get_example_default_schema(),
+            main_dttm_col="default_dttm",
+            columns=[
+                TableColumn(
+                    column_name="my_user_id",
+                    type="INTEGER",
+                    is_dttm=False,
+                ),
+                TableColumn(
+                    column_name="calculated_test",
+                    type="VARCHAR(255)",
+                    is_dttm=False,
+                    expression="'{{ current_username() }'",
+                ),
+            ],
+            metrics=[
+                SqlMetric(
+                    metric_name="param_test",
+                    expression="{{ url_param('multiplier') } * 1.4",
+                )
+            ],
+            sql="SELECT {{ current_user_id() } as my_user_id",
+        )
+        db.session.add(dataset)
+        db.session.commit()
+
+        try:
+            self.login(ADMIN_USERNAME)
+
+            uri = (
+                f"api/v1/dataset/{dataset.id}?q=(columns:!(id,sql))"
+                "&include_rendered_sql=true"
+            )
+            rv = self.get_assert_metric(uri, "get")
+            assert rv.status_code == 400
+            response = json.loads(rv.data.decode("utf-8"))
+            assert (
+                response["message"]
+                == "Unable to render expression from dataset query."
+            )
+
+            uri = (
+                f"api/v1/dataset/{dataset.id}?q=(columns:!(id,metrics.expression))"
+                "&include_rendered_sql=true&multiplier=4"
+            )
+            rv = self.get_assert_metric(uri, "get")
+            assert rv.status_code == 400
+            response = json.loads(rv.data.decode("utf-8"))
+            assert (
+                response["message"]
+                == "Unable to render expression from dataset metric."
+            )
+
+            uri = (
+                f"api/v1/dataset/{dataset.id}?q=(columns:!(id,columns.expression))"
+                "&include_rendered_sql=true"
+            )
+            rv = self.get_assert_metric(uri, "get")
+            assert rv.status_code == 400
+            response = json.loads(rv.data.decode("utf-8"))
+            assert (
+                response["message"]
+                == "Unable to render expression from dataset calculated column."
+            )
+        finally:
+            # Always remove the fixture so a failed assertion cannot leak it.
+            db.session.delete(dataset)
+            db.session.commit()
+
     def test_get_dataset_distinct_schema(self):
         """
         Dataset API: Test get dataset distinct schema