Skip to content

Commit

Permalink
feat(dataset API): Add parameter to optionally render Jinja macros in…
Browse files Browse the repository at this point in the history
… API response (apache#30721)
  • Loading branch information
Vitor-Avila authored Oct 30, 2024
1 parent d5a98e0 commit e79778a
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 2 deletions.
150 changes: 149 additions & 1 deletion superset/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,24 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from __future__ import annotations

import logging
from datetime import datetime
from io import BytesIO
from typing import Any
from typing import Any, Callable
from zipfile import is_zipfile, ZipFile

from flask import request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.api.schemas import get_item_schema
from flask_appbuilder.const import (
API_RESULT_RES_KEY,
API_SELECT_COLUMNS_RIS_KEY,
)
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
from jinja2.exceptions import TemplateSyntaxError
from marshmallow import ValidationError

from superset import event_logger
Expand Down Expand Up @@ -65,6 +73,8 @@
GetOrCreateDatasetSchema,
openapi_spec_methods_override,
)
from superset.exceptions import SupersetTemplateException
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
from superset.utils import json
from superset.utils.core import parse_boolean_string
from superset.views.base import DatasourceFilter
Expand All @@ -75,6 +85,7 @@
requires_json,
statsd_metrics,
)
from superset.views.error_handling import handle_api_exception
from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1056,3 +1067,140 @@ def warm_up_cache(self) -> Response:
return self.response(200, result=result)
except CommandException as ex:
return self.response(ex.status, message=ex.message)

@expose("/<int:pk>", methods=("GET",))
@protect()
@safe
@rison(get_item_schema)
@statsd_metrics
@handle_api_exception
@event_logger.log_this_with_context(
    action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
    log_to_statsd=False,
)
def get(self, pk: int, **kwargs: Any) -> Response:
    """Get a dataset.

    Looks the dataset up by primary key, serializes it with either the
    default show schema or a schema restricted to the requested columns,
    and, when the ``include_rendered_sql`` query argument parses as true,
    passes the serialized payload through ``render_dataset_fields``.
    A ``SupersetTemplateException`` raised while rendering is answered
    with a 400 response carrying the exception message.
    ---
    get:
      summary: Get a dataset
      description: Get a dataset by ID
      parameters:
      - in: path
        schema:
          type: integer
        description: The dataset ID
        name: pk
      - in: query
        name: q
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/get_item_schema'
      - in: query
        name: include_rendered_sql
        description: >-
          Should Jinja macros from sql, metrics and columns be rendered
          and included in the response
        schema:
          type: boolean
      responses:
        200:
          description: Dataset object has been returned.
          content:
            application/json:
              schema:
                type: object
                properties:
                  id:
                    description: The item id
                    type: string
                  result:
                    $ref: '#/components/schemas/{{self.__class__.__name__}}.get'
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        422:
          $ref: '#/components/responses/422'
        500:
          $ref: '#/components/responses/500'
    """
    item: SqlaTable | None = self.datamodel.get(
        pk,
        self._base_filters,
        self.show_select_columns,
        self.show_outer_default_load,
    )
    if not item:
        return self.response_404()

    response: dict[str, Any] = {}
    args = kwargs.get("rison", {})
    # Keep only the requested columns that are part of ``show_columns``.
    select_cols = args.get(API_SELECT_COLUMNS_RIS_KEY, [])
    pruned_select_cols = [col for col in select_cols if col in self.show_columns]
    self.set_response_key_mappings(
        response,
        self.get,
        args,
        **{API_SELECT_COLUMNS_RIS_KEY: pruned_select_cols},
    )
    # A column selection gets an ad-hoc schema; otherwise use the default one.
    if pruned_select_cols:
        show_model_schema = self.model2schemaconverter.convert(pruned_select_cols)
    else:
        show_model_schema = self.show_model_schema

    response["id"] = pk
    response[API_RESULT_RES_KEY] = show_model_schema.dump(item, many=False)

    if parse_boolean_string(request.args.get("include_rendered_sql")):
        try:
            processor = get_template_processor(database=item.database)
            # Address the payload through the same constant used to store it,
            # rather than a separate hard-coded "result" literal.
            response[API_RESULT_RES_KEY] = self.render_dataset_fields(
                response[API_RESULT_RES_KEY], processor
            )
        except SupersetTemplateException as ex:
            return self.response_400(message=str(ex))
    return self.response(200, **response)

@staticmethod
def render_dataset_fields(
data: dict[str, Any], processor: BaseTemplateProcessor
) -> dict[str, Any]:
"""
Renders Jinja macros in the ``sql``, ``metrics`` and ``columns`` fields.
:param data: Dataset info to be rendered
:param processor: A ``TemplateProcessor`` instance
:return: Rendered dataset data
"""

def render_item_list(item_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
return [
{
**item,
"rendered_expression": processor.process_template(
item["expression"]
),
}
if item.get("expression")
else item
for item in item_list
]

items: list[tuple[str, str, str, Callable[[Any], Any]]] = [
("query", "sql", "rendered_sql", processor.process_template),
("metric", "metrics", "metrics", render_item_list),
("calculated column", "columns", "columns", render_item_list),
]
for item_type, key, new_key, func in items:
if not data.get(key):
continue

try:
data[new_key] = func(data[key])
except TemplateSyntaxError as ex:
raise SupersetTemplateException(
f"Unable to render expression from dataset {item_type}.",
) from ex

return data
1 change: 0 additions & 1 deletion superset/datasets/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}

openapi_spec_methods_override = {
"get": {"get": {"summary": "Get a dataset detail information"}},
"get_list": {
"get": {
"summary": "Get a list of datasets",
Expand Down
139 changes: 139 additions & 0 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,145 @@ def test_get_dataset_item(self):
assert len(response["result"]["columns"]) == 3
assert len(response["result"]["metrics"]) == 2

def test_get_dataset_render_jinja(self):
    """
    Dataset API: Test get dataset with the render parameter.

    Creates a dataset whose SQL, calculated column and metric all contain
    Jinja, requests it with ``include_rendered_sql=true`` and checks that
    the rendered values are returned next to the original templates.
    """
    database = get_example_database()
    dataset = SqlaTable(
        table_name="test_sql_table_with_jinja",
        database=database,
        schema=get_example_default_schema(),
        main_dttm_col="default_dttm",
        columns=[
            TableColumn(
                column_name="my_user_id",
                type="INTEGER",
                is_dttm=False,
            ),
            TableColumn(
                column_name="calculated_test",
                type="VARCHAR(255)",
                is_dttm=False,
                expression="'{{ current_username() }}'",
            ),
        ],
        metrics=[
            SqlMetric(
                metric_name="param_test",
                expression="{{ url_param('multiplier') }} * 1.4",
            )
        ],
        sql="SELECT {{ current_user_id() }} as my_user_id",
    )
    db.session.add(dataset)
    db.session.commit()

    # Clean up in ``finally`` so a failing assertion does not leave the
    # dataset behind in the test database.
    try:
        self.login(ADMIN_USERNAME)
        admin = self.get_user(ADMIN_USERNAME)
        uri = (
            f"api/v1/dataset/{dataset.id}?"
            "q=(columns:!(id,sql,columns.column_name,columns.expression,metrics.metric_name,metrics.expression))"
            "&include_rendered_sql=true&multiplier=4"
        )
        rv = self.get_assert_metric(uri, "get")
        assert rv.status_code == 200
        response = json.loads(rv.data.decode("utf-8"))

        assert response["result"] == {
            "id": dataset.id,
            "sql": "SELECT {{ current_user_id() }} as my_user_id",
            "rendered_sql": f"SELECT {admin.id} as my_user_id",
            "columns": [
                {
                    "column_name": "my_user_id",
                    "expression": None,
                },
                {
                    "column_name": "calculated_test",
                    "expression": "'{{ current_username() }}'",
                    "rendered_expression": f"'{admin.username}'",
                },
            ],
            "metrics": [
                {
                    "metric_name": "param_test",
                    "expression": "{{ url_param('multiplier') }} * 1.4",
                    "rendered_expression": "4 * 1.4",
                },
            ],
        }
    finally:
        db.session.delete(dataset)
        db.session.commit()

def test_get_dataset_render_jinja_exceptions(self):
    """
    Dataset API: Test get dataset with the render parameter
    when rendering raises an exception.

    The dataset's SQL, metric and calculated column each contain a
    malformed Jinja expression (a missing closing brace). Each is
    requested on its own and must produce a 400 response whose message
    names the kind of field that failed.
    """
    database = get_example_database()
    dataset = SqlaTable(
        table_name="test_sql_table_with_incorrect_jinja",
        database=database,
        schema=get_example_default_schema(),
        main_dttm_col="default_dttm",
        columns=[
            TableColumn(
                column_name="my_user_id",
                type="INTEGER",
                is_dttm=False,
            ),
            TableColumn(
                column_name="calculated_test",
                type="VARCHAR(255)",
                is_dttm=False,
                expression="'{{ current_username() }'",
            ),
        ],
        metrics=[
            SqlMetric(
                metric_name="param_test",
                expression="{{ url_param('multiplier') } * 1.4",
            )
        ],
        sql="SELECT {{ current_user_id() } as my_user_id",
    )
    db.session.add(dataset)
    db.session.commit()

    # Clean up in ``finally`` so a failing assertion does not leave the
    # dataset behind in the test database.
    try:
        self.login(ADMIN_USERNAME)

        # (request URI, expected error message) for each field kind.
        cases = [
            (
                f"api/v1/dataset/{dataset.id}?q=(columns:!(id,sql))&include_rendered_sql=true",
                "Unable to render expression from dataset query.",
            ),
            (
                f"api/v1/dataset/{dataset.id}?q=(columns:!(id,metrics.expression))"
                "&include_rendered_sql=true&multiplier=4",
                "Unable to render expression from dataset metric.",
            ),
            (
                f"api/v1/dataset/{dataset.id}?q=(columns:!(id,columns.expression))"
                "&include_rendered_sql=true",
                "Unable to render expression from dataset calculated column.",
            ),
        ]
        for uri, expected_message in cases:
            rv = self.get_assert_metric(uri, "get")
            assert rv.status_code == 400
            response = json.loads(rv.data.decode("utf-8"))
            assert response["message"] == expected_message
    finally:
        db.session.delete(dataset)
        db.session.commit()

def test_get_dataset_distinct_schema(self):
"""
Dataset API: Test get dataset distinct schema
Expand Down

0 comments on commit e79778a

Please sign in to comment.