From 05184c1c57d9a21fcfd7029234d59765b21e6689 Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Mon, 29 Jan 2024 10:03:24 -0800 Subject: [PATCH] feat: add chart id and dataset id to global logs (#26443) (cherry picked from commit 78dc6ce6c9514a1d161f4b5bdab4148e1e4a28a5) --- superset/charts/data/api.py | 21 +++++++- .../charts/data/api_tests.py | 53 +++++++++++++++++-- 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index a62e6a2407451..f0bc98f253710 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -47,7 +47,13 @@ from superset.exceptions import QueryObjectValidationError from superset.extensions import event_logger from superset.models.sql_lab import Query -from superset.utils.core import create_zip, get_user_id, json_int_dttm_ser +from superset.utils.core import ( + create_zip, + DatasourceType, + get_user_id, + json_int_dttm_ser, +) +from superset.utils.decorators import logs_context from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse from superset.views.base_api import statsd_metrics @@ -421,6 +427,19 @@ def _get_data_response( def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]: return QueryContextCacheLoader.load(cache_key) + def _map_form_data_datasource_to_dataset_id( + self, form_data: dict[str, Any] + ) -> dict[str, Any]: + return { + "dataset_id": form_data.get("datasource", {}).get("id") + if isinstance(form_data.get("datasource"), dict) + and form_data.get("datasource", {}).get("type") + == DatasourceType.TABLE.value + else None, + "slice_id": form_data.get("form_data", {}).get("slice_id"), + } + + @logs_context(context_func=_map_form_data_datasource_to_dataset_id) def _create_query_context_from_form( self, form_data: dict[str, Any] ) -> QueryContext: diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 4def03ff4e484..cc45243ee5448 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -27,6 +27,7 @@ from flask import Response from tests.integration_tests.conftest import with_feature_flags +from superset.charts.data.api import ChartDataRestApi from superset.models.sql_lab import Query from tests.integration_tests.base_tests import SupersetTestCase, test_client from tests.integration_tests.annotation_layers.fixtures import create_annotation_layers @@ -47,7 +48,6 @@ from superset.errors import SupersetErrorType from superset.extensions import async_query_manager_factory, db from superset.models.annotations import AnnotationLayer -from superset.models.slice import Slice from superset.superset_typing import AdhocColumn from superset.utils.core import ( AnnotationType, @@ -88,7 +88,9 @@ def setUp(self) -> None: BaseTestChartDataApi.query_context_payload_template = get_query_context( "birth_names" ) - self.query_context_payload = copy.deepcopy(self.query_context_payload_template) + self.query_context_payload = ( + copy.deepcopy(self.query_context_payload_template) or {} + ) def get_expected_row_count(self, client_id: str) -> int: start_date = datetime.now() @@ -124,7 +126,49 @@ def quote_name(self, name: str): @pytest.mark.chart_data_flow class TestPostChartDataApi(BaseTestChartDataApi): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_with_valid_qc__data_is_returned(self): + def test__map_form_data_datasource_to_dataset_id(self): + # arrange + self.query_context_payload["datasource"] = {"id": 1, "type": "table"} + # act + response = ChartDataRestApi._map_form_data_datasource_to_dataset_id( + ChartDataRestApi, self.query_context_payload + ) + # assert + assert response == {"dataset_id": 1, "slice_id": None} + + # takes malformed content without raising an error + self.query_context_payload["datasource"] = "1__table" + # act + response = ChartDataRestApi._map_form_data_datasource_to_dataset_id( + ChartDataRestApi, self.query_context_payload + ) + # assert + assert response == {"dataset_id": None, "slice_id": None} + + # takes a slice id + self.query_context_payload["datasource"] = None + self.query_context_payload["form_data"] = {"slice_id": 1} + # act + response = ChartDataRestApi._map_form_data_datasource_to_dataset_id( + ChartDataRestApi, self.query_context_payload + ) + # assert + assert response == {"dataset_id": None, "slice_id": 1} + + # takes missing slice id + self.query_context_payload["datasource"] = None + self.query_context_payload["form_data"] = {"foo": 1} + # act + response = ChartDataRestApi._map_form_data_datasource_to_dataset_id( + ChartDataRestApi, self.query_context_payload + ) + # assert + assert response == {"dataset_id": None, "slice_id": None} + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + @mock.patch("superset.utils.decorators.g") + def test_with_valid_qc__data_is_returned(self, mock_g): + mock_g.logs_context = {} # arrange expected_row_count = self.get_expected_row_count("client_id_1") # act @@ -133,6 +177,9 @@ def test_with_valid_qc__data_is_returned(self): assert rv.status_code == 200 self.assert_row_count(rv, expected_row_count) + # check that global logs decorator is capturing from form_data + assert isinstance(mock_g.logs_context.get("dataset_id"), int) + @staticmethod def assert_row_count(rv: Response, expected_row_count: int): assert rv.json["result"][0]["rowcount"] == expected_row_count