From 33b934cbb346b464260dc1e2f4218713595a07e1 Mon Sep 17 00:00:00 2001 From: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com> Date: Thu, 11 Jul 2024 12:26:36 -0300 Subject: [PATCH] fix(Tags filter): Filter assets by tag ID (#29412) --- .../src/components/ListView/types.ts | 9 +- .../src/pages/ChartList/index.tsx | 2 +- .../src/pages/DashboardList/index.tsx | 2 +- .../src/pages/SavedQueryList/index.tsx | 2 +- superset/charts/api.py | 5 +- superset/charts/filters.py | 19 ++- superset/dashboards/api.py | 5 +- superset/dashboards/filters.py | 19 ++- superset/queries/saved_queries/api.py | 15 +- superset/queries/saved_queries/filters.py | 19 ++- superset/tags/filters.py | 54 ++++++ superset/views/base_api.py | 25 --- tests/integration_tests/base_tests.py | 16 ++ tests/integration_tests/charts/api_tests.py | 153 +++++++++++++---- .../integration_tests/dashboards/api_tests.py | 159 +++++++++++++++--- tests/integration_tests/fixtures/tags.py | 35 ++++ .../queries/saved_queries/api_tests.py | 121 +++++++++++++ tests/unit_tests/tags/filters_test.py | 85 ++++++++++ 18 files changed, 636 insertions(+), 109 deletions(-) create mode 100644 tests/unit_tests/tags/filters_test.py diff --git a/superset-frontend/src/components/ListView/types.ts b/superset-frontend/src/components/ListView/types.ts index ca3a8b3c70923..d7c7cd5117111 100644 --- a/superset-frontend/src/components/ListView/types.ts +++ b/superset-frontend/src/components/ListView/types.ts @@ -117,7 +117,10 @@ export enum FilterOperator { DatasetIsCertified = 'dataset_is_certified', DashboardHasCreatedBy = 'dashboard_has_created_by', ChartHasCreatedBy = 'chart_has_created_by', - DashboardTags = 'dashboard_tags', - ChartTags = 'chart_tags', - SavedQueryTags = 'saved_query_tags', + DashboardTagByName = 'dashboard_tags', + DashboardTagById = 'dashboard_tag_id', + ChartTagByName = 'chart_tags', + ChartTagById = 'chart_tag_id', + SavedQueryTagByName = 'saved_query_tags', + SavedQueryTagById = 'saved_query_tag_id', } diff --git a/superset-frontend/src/pages/ChartList/index.tsx b/superset-frontend/src/pages/ChartList/index.tsx index 65ec54a40b2cc..6650583534bdf 100644 --- a/superset-frontend/src/pages/ChartList/index.tsx +++ b/superset-frontend/src/pages/ChartList/index.tsx @@ -614,7 +614,7 @@ function ChartList(props: ChartListProps) { key: 'tags', id: 'tags', input: 'select', - operator: FilterOperator.ChartTags, + operator: FilterOperator.ChartTagById, unfilteredLabel: t('All'), fetchSelects: loadTags, }, diff --git a/superset-frontend/src/pages/DashboardList/index.tsx b/superset-frontend/src/pages/DashboardList/index.tsx index aa577749d4770..8ffc51ce2a6f3 100644 --- a/superset-frontend/src/pages/DashboardList/index.tsx +++ b/superset-frontend/src/pages/DashboardList/index.tsx @@ -547,7 +547,7 @@ function DashboardList(props: DashboardListProps) { key: 'tags', id: 'tags', input: 'select', - operator: FilterOperator.DashboardTags, + operator: FilterOperator.DashboardTagById, unfilteredLabel: t('All'), fetchSelects: loadTags, }, diff --git a/superset-frontend/src/pages/SavedQueryList/index.tsx b/superset-frontend/src/pages/SavedQueryList/index.tsx index dd4506185c041..72836d60594da 100644 --- a/superset-frontend/src/pages/SavedQueryList/index.tsx +++ b/superset-frontend/src/pages/SavedQueryList/index.tsx @@ -501,7 +501,7 @@ function SavedQueryList({ id: 'tags', key: 'tags', input: 'select', - operator: FilterOperator.SavedQueryTags, + operator: FilterOperator.SavedQueryTagById, fetchSelects: loadTags, }, ] diff --git a/superset/charts/api.py b/superset/charts/api.py index d32f1f665ae14..d814d0fa02a98 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -39,7 +39,8 @@ ChartFilter, ChartHasCreatedByFilter, ChartOwnedCreatedFavoredByMeFilter, - ChartTagFilter, + ChartTagIdFilter, + ChartTagNameFilter, ) from superset.charts.schemas import ( CHART_SCHEMAS, @@ -238,7 +239,7 @@ def ensure_thumbnails_enabled(self) -> Optional[Response]: ], "slice_name": [ChartAllTextFilter], "created_by": [ChartHasCreatedByFilter, ChartCreatedByMeFilter], - "tags": [ChartTagFilter], + "tags": [ChartTagNameFilter, ChartTagIdFilter], } # Will just affect _info endpoint edit_columns = ["slice_name"] diff --git a/superset/charts/filters.py b/superset/charts/filters.py index a7543ba284a82..f9748dd0ecb69 100644 --- a/superset/charts/filters.py +++ b/superset/charts/filters.py @@ -26,10 +26,11 @@ from superset.connectors.sqla.models import SqlaTable from superset.models.core import FavStar from superset.models.slice import Slice +from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter from superset.utils.core import get_user_id from superset.utils.filters import get_dataset_access_filters from superset.views.base import BaseFilter -from superset.views.base_api import BaseFavoriteFilter, BaseTagFilter +from superset.views.base_api import BaseFavoriteFilter class ChartAllTextFilter(BaseFilter): # pylint: disable=too-few-public-methods @@ -60,9 +61,10 @@ class ChartFavoriteFilter(BaseFavoriteFilter): # pylint: disable=too-few-public model = Slice -class ChartTagFilter(BaseTagFilter): # pylint: disable=too-few-public-methods +class ChartTagNameFilter(BaseTagNameFilter): # pylint: disable=too-few-public-methods """ - Custom filter for the GET list that filters all dashboards that a user has favored + Custom filter for the GET list that filters all charts associated with + a certain tag (by its name). """ arg_name = "chart_tags" @@ -70,6 +72,17 @@ class ChartTagFilter(BaseTagFilter): # pylint: disable=too-few-public-methods model = Slice +class ChartTagIdFilter(BaseTagIdFilter): # pylint: disable=too-few-public-methods + """ + Custom filter for the GET list that filters all charts associated with + a certain tag (by its ID). + """ + + arg_name = "chart_tag_id" + class_name = "slice" + model = Slice + + class ChartCertifiedFilter(BaseFilter): # pylint: disable=too-few-public-methods """ Custom filter for the GET list that filters all certified charts diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index 2967fd1abd4bd..716e4c416d0e1 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -60,7 +60,8 @@ DashboardCreatedByMeFilter, DashboardFavoriteFilter, DashboardHasCreatedByFilter, - DashboardTagFilter, + DashboardTagIdFilter, + DashboardTagNameFilter, DashboardTitleOrSlugFilter, FilterRelatedRoles, ) @@ -244,7 +245,7 @@ def ensure_thumbnails_enabled(self) -> Optional[Response]: "dashboard_title": [DashboardTitleOrSlugFilter], "id": [DashboardFavoriteFilter, DashboardCertifiedFilter], "created_by": [DashboardCreatedByMeFilter, DashboardHasCreatedByFilter], - "tags": [DashboardTagFilter], + "tags": [DashboardTagIdFilter, DashboardTagNameFilter], } base_order = ("changed_on", "desc") diff --git a/superset/dashboards/filters.py b/superset/dashboards/filters.py index 0c7878d508626..9a4c496b20b31 100644 --- a/superset/dashboards/filters.py +++ b/superset/dashboards/filters.py @@ -29,10 +29,11 @@ from superset.models.embedded_dashboard import EmbeddedDashboard from superset.models.slice import Slice from superset.security.guest_token import GuestTokenResourceType, GuestUser +from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter from superset.utils.core import get_user_id from superset.utils.filters import get_dataset_access_filters from superset.views.base import BaseFilter -from superset.views.base_api import BaseFavoriteFilter, BaseTagFilter +from superset.views.base_api import BaseFavoriteFilter class DashboardTitleOrSlugFilter(BaseFilter): # pylint: disable=too-few-public-methods @@ -78,9 +79,10 @@ class DashboardFavoriteFilter( # pylint: disable=too-few-public-methods model = Dashboard -class DashboardTagFilter(BaseTagFilter): # pylint: disable=too-few-public-methods +class DashboardTagNameFilter(BaseTagNameFilter): # pylint: disable=too-few-public-methods """ - Custom filter for the GET list that filters all dashboards that a user has favored + Custom filter for the GET list that filters all dashboards associated with + a certain tag (by its name). """ arg_name = "dashboard_tags" @@ -88,6 +90,17 @@ class DashboardTagFilter(BaseTagFilter): # pylint: disable=too-few-public-metho model = Dashboard +class DashboardTagIdFilter(BaseTagIdFilter): # pylint: disable=too-few-public-methods + """ + Custom filter for the GET list that filters all dashboards associated with + a certain tag (by its ID). + """ + + arg_name = "dashboard_tag_id" + class_name = "Dashboard" + model = Dashboard + + class DashboardAccessFilter(BaseFilter): # pylint: disable=too-few-public-methods """ List dashboards with the following criteria: diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py index cd7b04193ff86..4e34a75039f12 100644 --- a/superset/queries/saved_queries/api.py +++ b/superset/queries/saved_queries/api.py @@ -25,7 +25,6 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext -from superset import is_feature_enabled from superset.commands.importers.exceptions import ( IncorrectFormatError, NoValidFilesFoundError, @@ -46,7 +45,8 @@ SavedQueryAllTextFilter, SavedQueryFavoriteFilter, SavedQueryFilter, - SavedQueryTagFilter, + SavedQueryTagIdFilter, + SavedQueryTagNameFilter, ) from superset.queries.saved_queries.schemas import ( get_delete_ids_schema, @@ -124,9 +124,10 @@ class SavedQueryRestApi(BaseSupersetModelRestApi): "schema", "sql", "sql_tables", + "tags.id", + "tags.name", + "tags.type", ] - if is_feature_enabled("TAGGING_SYSTEM"): - list_columns += ["tags.id", "tags.name", "tags.type"] list_select_columns = list_columns + ["changed_by_fk", "changed_on"] add_columns = [ "db_id", @@ -161,15 +162,13 @@ class SavedQueryRestApi(BaseSupersetModelRestApi): "schema", "created_by", "changed_by", + "tags", ] - if is_feature_enabled("TAGGING_SYSTEM"): - search_columns += ["tags"] search_filters = { "id": [SavedQueryFavoriteFilter], "label": [SavedQueryAllTextFilter], + "tags": [SavedQueryTagNameFilter, SavedQueryTagIdFilter], } - if is_feature_enabled("TAGGING_SYSTEM"): - search_filters["tags"] = [SavedQueryTagFilter] apispec_parameter_schemas = { "get_delete_ids_schema": get_delete_ids_schema, diff --git a/superset/queries/saved_queries/filters.py b/superset/queries/saved_queries/filters.py index 90e356163fde2..821f42d6f1120 100644 --- a/superset/queries/saved_queries/filters.py +++ b/superset/queries/saved_queries/filters.py @@ -23,8 +23,9 @@ from sqlalchemy.orm.query import Query from superset.models.sql_lab import SavedQuery +from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter from superset.views.base import BaseFilter -from superset.views.base_api import BaseFavoriteFilter, BaseTagFilter +from superset.views.base_api import BaseFavoriteFilter class SavedQueryAllTextFilter(BaseFilter): # pylint: disable=too-few-public-methods @@ -56,9 +57,10 @@ class SavedQueryFavoriteFilter(BaseFavoriteFilter): # pylint: disable=too-few-p model = SavedQuery -class SavedQueryTagFilter(BaseTagFilter): # pylint: disable=too-few-public-methods +class SavedQueryTagNameFilter(BaseTagNameFilter): # pylint: disable=too-few-public-methods """ - Custom filter for the GET list that filters all dashboards that a user has favored + Custom filter for the GET list that filters all saved queries associated with + a certain tag (by its name). """ arg_name = "saved_query_tags" @@ -66,6 +68,17 @@ class SavedQueryTagFilter(BaseTagFilter): # pylint: disable=too-few-public-meth model = SavedQuery +class SavedQueryTagIdFilter(BaseTagIdFilter): # pylint: disable=too-few-public-methods + """ + Custom filter for the GET list that filters all saved queries associated with + a certain tag (by its ID). + """ + + arg_name = "saved_query_tag_id" + class_name = "query" + model = SavedQuery + + class SavedQueryFilter(BaseFilter): # pylint: disable=too-few-public-methods def apply(self, query: BaseQuery, value: Any) -> BaseQuery: """ diff --git a/superset/tags/filters.py b/superset/tags/filters.py index ff6be712d3368..81df9fd7b931b 100644 --- a/superset/tags/filters.py +++ b/superset/tags/filters.py @@ -14,9 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +from typing import Any + from flask_babel import lazy_gettext as _ from sqlalchemy.orm import Query +from superset.connectors.sqla.models import SqlaTable +from superset.extensions import db +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.sql_lab import Query as SqllabQuery from superset.tags.models import Tag, TagType from superset.views.base import BaseFilter @@ -37,3 +46,48 @@ def apply(self, query: Query, value: bool) -> Query: if value is False: return query.filter(Tag.type != TagType.custom) return query + + +class BaseTagNameFilter(BaseFilter): # pylint: disable=too-few-public-methods + """ + Base Custom filter for the GET list that filters all dashboards, slices + and saved queries associated with a tag (by the tag name). + """ + + name = _("Is tagged") + arg_name = "" + class_name = "" + """ The Tag class_name to user """ + model: type[Dashboard | Slice | SqllabQuery | SqlaTable] = Dashboard + """ The SQLAlchemy model """ + + def apply(self, query: Query, value: Any) -> Query: + ilike_value = f"%{value}%" + tags_query = ( + db.session.query(self.model.id) + .join(self.model.tags) + .filter(Tag.name.ilike(ilike_value)) + ) + return query.filter(self.model.id.in_(tags_query)) + + +class BaseTagIdFilter(BaseFilter): # pylint: disable=too-few-public-methods + """ + Base Custom filter for the GET list that filters all dashboards, slices + and saved queries associated with a tag (by the tag ID). + """ + + name = _("Is tagged") + arg_name = "" + class_name = "" + """ The Tag class_name to user """ + model: type[Dashboard | Slice | SqllabQuery | SqlaTable] = Dashboard + """ The SQLAlchemy model """ + + def apply(self, query: Query, value: Any) -> Query: + tags_query = ( + db.session.query(self.model.id) + .join(self.model.tags) + .filter(Tag.id == value) + ) + return query.filter(self.model.id.in_(tags_query)) diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 5c71147517664..8240481adaa61 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -31,7 +31,6 @@ from sqlalchemy import and_, distinct, func from sqlalchemy.orm.query import Query -from superset.connectors.sqla.models import SqlaTable from superset.exceptions import InvalidPayloadFormatError from superset.extensions import db, event_logger, security_manager, stats_logger_manager from superset.models.core import FavStar @@ -40,7 +39,6 @@ from superset.schemas import error_payload_content from superset.sql_lab import Query as SqllabQuery from superset.superset_typing import FlaskResponse -from superset.tags.models import Tag from superset.utils.core import get_user_id, time_function from superset.views.error_handling import handle_api_exception @@ -168,29 +166,6 @@ def apply(self, query: Query, value: Any) -> Query: return query.filter(and_(~self.model.id.in_(users_favorite_query))) -class BaseTagFilter(BaseFilter): # pylint: disable=too-few-public-methods - """ - Base Custom filter for the GET list that filters all dashboards, slices - that a user has favored or not - """ - - name = _("Is tagged") - arg_name = "" - class_name = "" - """ The Tag class_name to user """ - model: type[Dashboard | Slice | SqllabQuery | SqlaTable] = Dashboard - """ The SQLAlchemy model """ - - def apply(self, query: Query, value: Any) -> Query: - ilike_value = f"%{value}%" - tags_query = ( - db.session.query(self.model.id) - .join(self.model.tags) - .filter(Tag.name.ilike(ilike_value)) - ) - return query.filter(self.model.id.in_(tags_query)) - - class BaseSupersetApiMixin: csrf_exempt = False diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 0e407b86573d4..b3a000c60136a 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -24,6 +24,7 @@ from unittest.mock import Mock, patch, MagicMock import pandas as pd +import prison from flask import Response, g from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase @@ -33,6 +34,7 @@ from sqlalchemy.sql import func from sqlalchemy.dialects.mysql import dialect +from tests.integration_tests.constants import ADMIN_USERNAME from tests.integration_tests.test_app import app, login from superset.sql_parse import CtasMethod from superset import db, security_manager @@ -589,6 +591,20 @@ def insert_dashboard( db.session.commit() return dashboard + def get_list( + self, + asset_type: str, + filter: dict[str, Any] = {}, + username: str = ADMIN_USERNAME, + ) -> Response: + """ + Get list of assets, by default using admin account. Can be filtered. + """ + self.login(username) + uri = f"api/v1/{asset_type}/?q={prison.dumps(filter)}" + response = self.get_assert_metric(uri, "get_list") + return response + @contextmanager def db_insert_temp_object(obj: DeclarativeMeta): diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index a9af7c12b3994..0f5948ad7b238 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -31,7 +31,7 @@ from superset.commands.chart.data.get_data_command import ChartDataCommand from superset.commands.chart.exceptions import ChartDataQueryFailedError from superset.connectors.sqla.models import SqlaTable -from superset.extensions import cache_manager, db, security_manager # noqa: F401 +from superset.extensions import cache_manager, db, security_manager from superset.models.core import Database, FavStar, FavStarClassName from superset.models.dashboard import Dashboard from superset.models.slice import Slice @@ -39,11 +39,8 @@ from superset.tags.models import ObjectType, Tag, TaggedObject, TagType from superset.utils import json from superset.utils.core import get_example_default_schema -from superset.utils.database import get_example_database # noqa: F401 -from superset.viz import viz_types # noqa: F401 from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin from tests.integration_tests.base_tests import SupersetTestCase -from tests.integration_tests.conftest import with_feature_flags # noqa: F401 from tests.integration_tests.constants import ( ADMIN_USERNAME, ALPHA_USERNAME, @@ -64,6 +61,10 @@ dataset_config, dataset_metadata_config, ) +from tests.integration_tests.fixtures.tags import ( + create_custom_tags, # noqa: F401 + get_filter_params, +) from tests.integration_tests.fixtures.unicode_dashboard import ( load_unicode_dashboard_with_slice, # noqa: F401 load_unicode_data, # noqa: F401 @@ -200,27 +201,8 @@ def add_dashboard_to_chart(self): db.session.delete(self.chart) db.session.commit() - @pytest.fixture() - def create_custom_tags(self): - with self.create_app().app_context(): - tags: list[Tag] = [] - for tag_name in {"one_tag", "new_tag"}: - tag = Tag( - name=tag_name, - type="custom", - ) - db.session.add(tag) - db.session.commit() - tags.append(tag) - - yield tags - - for tags in tags: - db.session.delete(tags) - db.session.commit() - - @pytest.fixture() - def create_chart_with_tag(self, create_custom_tags): + @pytest.fixture + def create_chart_with_tag(self, create_custom_tags): # noqa: F811 with self.create_app().app_context(): alpha_user = self.get_user(ALPHA_USERNAME) @@ -230,7 +212,7 @@ def create_chart_with_tag(self, create_custom_tags): 1, ) - tag = db.session.query(Tag).filter(Tag.name == "one_tag").first() + tag = db.session.query(Tag).filter(Tag.name == "first_tag").first() tag_association = TaggedObject( object_id=chart.id, object_type=ObjectType.chart, @@ -247,6 +229,70 @@ def create_chart_with_tag(self, create_custom_tags): db.session.delete(chart) db.session.commit() + @pytest.fixture + def create_charts_some_with_tags(self, create_custom_tags): # noqa: F811 + """ + Fixture that creates 4 charts: + - ``first_chart`` is associated with ``first_tag`` + - ``second_chart`` is associated with ``second_tag`` + - ``third_chart`` is associated with both ``first_tag`` and ``second_tag`` + - ``fourth_chart`` is not associated with any tag + + Relies on the ``create_custom_tags`` fixture for the tag creation. + """ + with self.create_app().app_context(): + admin_user = self.get_user(ADMIN_USERNAME) + + tags = { + "first_tag": db.session.query(Tag) + .filter(Tag.name == "first_tag") + .first(), + "second_tag": db.session.query(Tag) + .filter(Tag.name == "second_tag") + .first(), + } + + chart_names = ["first_chart", "second_chart", "third_chart", "fourth_chart"] + charts = [ + self.insert_chart(name, [admin_user.id], 1) for name in chart_names + ] + + tag_associations = [ + TaggedObject( + object_id=charts[0].id, + object_type=ObjectType.chart, + tag=tags["first_tag"], + ), + TaggedObject( + object_id=charts[1].id, + object_type=ObjectType.chart, + tag=tags["second_tag"], + ), + TaggedObject( + object_id=charts[2].id, + object_type=ObjectType.chart, + tag=tags["first_tag"], + ), + TaggedObject( + object_id=charts[2].id, + object_type=ObjectType.chart, + tag=tags["second_tag"], + ), + ] + + for association in tag_associations: + db.session.add(association) + db.session.commit() + + yield charts + + # rollback changes + for association in tag_associations: + db.session.delete(association) + for chart in charts: + db.session.delete(chart) + db.session.commit() + def test_info_security_chart(self): """ Chart API: Test info security @@ -1131,6 +1177,55 @@ def test_get_charts_dashboard_filter(self): assert len(result) == 1 assert result[0]["slice_name"] == self.chart.slice_name + @pytest.mark.usefixtures("create_charts_some_with_tags") + def test_get_charts_tag_filters(self): + """ + Chart API: Test get charts with tag filters + """ + # Get custom tags relationship + tags = { + "first_tag": db.session.query(Tag).filter(Tag.name == "first_tag").first(), + "second_tag": db.session.query(Tag) + .filter(Tag.name == "second_tag") + .first(), + "third_tag": db.session.query(Tag).filter(Tag.name == "third_tag").first(), + } + chart_tag_relationship = { + tag.name: db.session.query(Slice.id) + .join(Slice.tags) + .filter(Tag.id == tag.id) + .all() + for tag in tags.values() + } + + # Validate API results for each tag + for tag_name, tag in tags.items(): + expected_charts = chart_tag_relationship[tag_name] + + # Filter by tag ID + filter_params = get_filter_params("chart_tag_id", tag.id) + response_by_id = self.get_list("chart", filter_params) + self.assertEqual(response_by_id.status_code, 200) + data_by_id = json.loads(response_by_id.data.decode("utf-8")) + + # Filter by tag name + filter_params = get_filter_params("chart_tags", tag.name) + response_by_name = self.get_list("chart", filter_params) + self.assertEqual(response_by_name.status_code, 200) + data_by_name = json.loads(response_by_name.data.decode("utf-8")) + + # Compare results + self.assertEqual( + data_by_id["count"], + data_by_name["count"], + len(expected_charts), + ) + self.assertEqual( + set(chart["id"] for chart in data_by_id["result"]), + set(chart["id"] for chart in data_by_name["result"]), + set(chart.id for chart in expected_charts), + ) + def test_get_charts_changed_on(self): """ Dashboard API: Test get charts changed on @@ -2059,7 +2154,7 @@ def test_update_chart_add_tags_can_write_on_tag(self): chart = ( db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() ) - new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one() # get existing tag and add a new one new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom] @@ -2118,7 +2213,7 @@ def test_update_chart_add_tags_can_tag_on_chart(self): chart = ( db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() ) - new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one() # get existing tag and add a new one new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom] @@ -2183,7 +2278,7 @@ def test_update_chart_add_tags_missing_permission(self): chart = ( db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() ) - new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one() # get existing tag and add a new one new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom] diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 99a784e95f376..b4e2958ccd07c 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -30,7 +30,7 @@ from freezegun import freeze_time from sqlalchemy import and_ -from superset import app, db, security_manager # noqa: F401 +from superset import db, security_manager # noqa: F401 from superset.models.dashboard import Dashboard from superset.models.core import FavStar, FavStarClassName from superset.reports.models import ReportSchedule, ReportScheduleType @@ -41,7 +41,6 @@ from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin from tests.integration_tests.base_tests import SupersetTestCase -from tests.integration_tests.conftest import with_feature_flags # noqa: F401 from tests.integration_tests.constants import ( ADMIN_USERNAME, ALPHA_USERNAME, @@ -56,6 +55,10 @@ dataset_config, dataset_metadata_config, ) +from tests.integration_tests.fixtures.tags import ( + create_custom_tags, # noqa: F401 + get_filter_params, +) from tests.integration_tests.utils.get_dashboards import get_dashboards_ids from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 @@ -169,27 +172,8 @@ def create_dashboard_with_report(self): db.session.delete(dashboard) db.session.commit() - @pytest.fixture() - def create_custom_tags(self): - with self.create_app().app_context(): - tags: list[Tag] = [] - for tag_name in {"one_tag", "new_tag"}: - tag = Tag( - name=tag_name, - type="custom", - ) - db.session.add(tag) - db.session.commit() - tags.append(tag) - - yield tags - - for tags in tags: - db.session.delete(tags) - db.session.commit() - - @pytest.fixture() - def create_dashboard_with_tag(self, create_custom_tags): + @pytest.fixture + def create_dashboard_with_tag(self, create_custom_tags): # noqa: F811 with self.create_app().app_context(): gamma = self.get_user("gamma") @@ -198,7 +182,7 @@ def create_dashboard_with_tag(self, create_custom_tags): None, [gamma.id], ) - tag = db.session.query(Tag).filter(Tag.name == "one_tag").first() + tag = db.session.query(Tag).filter(Tag.name == "first_tag").first() tag_association = TaggedObject( object_id=dashboard.id, object_type=ObjectType.dashboard, @@ -215,6 +199,76 @@ def create_dashboard_with_tag(self, create_custom_tags): db.session.delete(dashboard) db.session.commit() + @pytest.fixture + def create_dashboards_some_with_tags(self, create_custom_tags): # noqa: F811 + """ + Fixture that creates 4 dashboards: + - ``first_dashboard`` is associated with ``first_tag`` + - ``second_dashboard`` is associated with ``second_tag`` + - ``third_dashboard`` is associated with both ``first_tag`` and ``second_tag`` + - ``fourth_dashboard`` is not associated with any tag + + Relies on the ``create_custom_tags`` fixture for the tag creation. + """ + with self.create_app().app_context(): + admin_user = self.get_user(ADMIN_USERNAME) + + tags = { + "first_tag": db.session.query(Tag) + .filter(Tag.name == "first_tag") + .first(), + "second_tag": db.session.query(Tag) + .filter(Tag.name == "second_tag") + .first(), + } + + dashboard_names = [ + "first_dashboard", + "second_dashboard", + "third_dashboard", + "fourth_dashboard", + ] + dashboards = [ + self.insert_dashboard(name, None, [admin_user.id]) + for name in dashboard_names + ] + + tag_associations = [ + TaggedObject( + object_id=dashboards[0].id, + object_type=ObjectType.chart, + tag=tags["first_tag"], + ), + TaggedObject( + object_id=dashboards[1].id, + object_type=ObjectType.chart, + tag=tags["second_tag"], + ), + TaggedObject( + object_id=dashboards[2].id, + object_type=ObjectType.chart, + tag=tags["first_tag"], + ), + TaggedObject( + object_id=dashboards[2].id, + object_type=ObjectType.chart, + tag=tags["second_tag"], + ), + ] + + for association in tag_associations: + db.session.add(association) + db.session.commit() + + yield dashboards + + # rollback changes + for association in tag_associations: + db.session.delete(association) + for chart in dashboards: + db.session.delete(chart) + db.session.commit() + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_get_dashboard_datasets(self): self.login(ADMIN_USERNAME) @@ -710,6 +764,55 @@ def test_get_dashboards_favorite_filter(self): expected_model.dashboard_title == data["result"][i]["dashboard_title"] ) + @pytest.mark.usefixtures("create_dashboards_some_with_tags") + def test_get_dashboards_tag_filters(self): + """ + Dashboard API: Test get dashboards with tag filters + """ + # Get custom tags relationship + tags = { + "first_tag": db.session.query(Tag).filter(Tag.name == "first_tag").first(), + "second_tag": db.session.query(Tag) + .filter(Tag.name == "second_tag") + .first(), + "third_tag": db.session.query(Tag).filter(Tag.name == "third_tag").first(), + } + dashboard_tag_relationship = { + tag.name: db.session.query(Dashboard.id) + .join(Dashboard.tags) + .filter(Tag.id == tag.id) + .all() + for tag in tags.values() + } + + # Validate API results for each tag + for tag_name, tag in tags.items(): + expected_dashboards = dashboard_tag_relationship[tag_name] + + # Filter by tag ID + filter_params = get_filter_params("dashboard_tag_id", tag.id) + response_by_id = self.get_list("dashboard", filter_params) + self.assertEqual(response_by_id.status_code, 200) + data_by_id = json.loads(response_by_id.data.decode("utf-8")) + + # Filter by tag name + filter_params = get_filter_params("dashboard_tags", tag.name) + response_by_name = self.get_list("dashboard", filter_params) + self.assertEqual(response_by_name.status_code, 200) + data_by_name = json.loads(response_by_name.data.decode("utf-8")) + + # Compare results + self.assertEqual( + data_by_id["count"], + data_by_name["count"], + len(expected_dashboards), + ) + self.assertEqual( + set(chart["id"] for chart in data_by_id["result"]), + set(chart["id"] for chart in data_by_name["result"]), + set(chart.id for chart in expected_dashboards), + ) + @pytest.mark.usefixtures("create_dashboards") def test_get_current_user_favorite_status(self): """ @@ -2504,7 +2607,7 @@ def test_update_dashboard_add_tags_can_write_on_tag(self): .filter(Dashboard.dashboard_title == "dash with tag") .first() ) - new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one() # get existing tag and add a new one new_tags = [tag.id for tag in dashboard.tags if tag.type == TagType.custom] @@ -2566,7 +2669,7 @@ def test_update_dashboard_add_tags_can_tag_on_dashboard(self): .filter(Dashboard.dashboard_title == "dash with tag") .first() ) - new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one() # get existing tag and add a new one new_tags = [tag.id for tag in dashboard.tags if tag.type == TagType.custom] @@ -2580,7 +2683,7 @@ def test_update_dashboard_add_tags_can_tag_on_dashboard(self): # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(tag_list, new_tags) + self.assertEqual(sorted(tag_list), sorted(new_tags)) security_manager.add_permission_role(gamma_role, write_tags_perm) @@ -2635,7 +2738,7 @@ def test_update_dashboard_add_tags_missing_permission(self): .filter(Dashboard.dashboard_title == "dash with tag") .first() ) - new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one() # get existing tag and add a new one new_tags = [tag.id for tag in dashboard.tags if tag.type == TagType.custom] diff --git a/tests/integration_tests/fixtures/tags.py b/tests/integration_tests/fixtures/tags.py index 493b3295d8597..90449957fa927 100644 --- a/tests/integration_tests/fixtures/tags.py +++ b/tests/integration_tests/fixtures/tags.py @@ -17,7 +17,9 @@ import pytest +from superset import db from superset.tags.core import clear_sqla_event_listeners, register_sqla_event_listeners +from superset.tags.models import Tag from tests.integration_tests.test_app import app @@ -31,3 +33,36 @@ def with_tagging_system_feature(): yield app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False clear_sqla_event_listeners() + + +@pytest.fixture +def create_custom_tags(): + with app.app_context(): + tags: list[Tag] = [] + for tag_name in {"first_tag", "second_tag", "third_tag"}: + tag = Tag( + name=tag_name, + type="custom", + ) + db.session.add(tag) + db.session.commit() + tags.append(tag) + + yield tags + + for tags in tags: + db.session.delete(tags) + db.session.commit() + + +# Helper function to return filter parameters +def get_filter_params(opr, value): + return { + "filters": [ + { + "col": "tags", + "opr": opr, + "value": value, + } + ] + } diff --git a/tests/integration_tests/queries/saved_queries/api_tests.py b/tests/integration_tests/queries/saved_queries/api_tests.py index da203a7139b23..9b1184b1f73f8 100644 --- a/tests/integration_tests/queries/saved_queries/api_tests.py +++ b/tests/integration_tests/queries/saved_queries/api_tests.py @@ -32,6 +32,7 @@ from superset.models.core import Database from superset.models.core import FavStar from superset.models.sql_lab import SavedQuery +from superset.tags.models import ObjectType, Tag, TaggedObject from superset.utils.database import get_example_database from superset.utils import json @@ -42,6 +43,10 @@ saved_queries_config, saved_queries_metadata_config, ) +from tests.integration_tests.fixtures.tags import ( + create_custom_tags, # noqa: F401 + get_filter_params, +) SAVED_QUERIES_FIXTURE_COUNT = 10 @@ -123,6 +128,73 @@ def create_saved_queries(self): db.session.delete(fav_saved_query) db.session.commit() + @pytest.fixture + def create_saved_queries_some_with_tags(self, create_custom_tags): # noqa: F811 + """ + Fixture that creates 4 saved queries: + - ``first_query`` is associated with ``first_tag`` + - ``second_query`` is associated with ``second_tag`` + - ``third_query`` is associated with both ``first_tag`` and ``second_tag`` + - ``fourth_query`` is not associated with any tag + + Relies on the ``create_custom_tags`` fixture for the tag creation. + """ + with self.create_app().app_context(): + tags = { + "first_tag": db.session.query(Tag) + .filter(Tag.name == "first_tag") + .first(), + "second_tag": db.session.query(Tag) + .filter(Tag.name == "second_tag") + .first(), + } + + query_labels = [ + "first_query", + "second_query", + "third_query", + "fourth_query", + ] + queries = [ + self.insert_default_saved_query(label=name) for name in query_labels + ] + + tag_associations = [ + TaggedObject( + object_id=queries[0].id, + object_type=ObjectType.chart, + tag=tags["first_tag"], + ), + TaggedObject( + object_id=queries[1].id, + object_type=ObjectType.chart, + tag=tags["second_tag"], + ), + TaggedObject( + object_id=queries[2].id, + object_type=ObjectType.chart, + tag=tags["first_tag"], + ), + TaggedObject( + object_id=queries[2].id, + object_type=ObjectType.chart, + tag=tags["second_tag"], + ), + ] + + for association in tag_associations: + db.session.add(association) + db.session.commit() + + yield queries + + # rollback changes + for association in tag_associations: + db.session.delete(association) + for chart in queries: + db.session.delete(chart) + db.session.commit() + @pytest.mark.usefixtures("create_saved_queries") def test_get_list_saved_query(self): """ @@ -366,6 +438,55 @@ def test_get_list_custom_filter_description_saved_query(self): data = json.loads(rv.data.decode("utf-8")) assert data["count"] == len(all_queries) + @pytest.mark.usefixtures("create_saved_queries_some_with_tags") + def test_get_saved_queries_tag_filters(self): + """ + Saved Query API: Test get saved queries with tag filters + """ + # Get custom tags relationship + tags = { + "first_tag": db.session.query(Tag).filter(Tag.name == "first_tag").first(), + "second_tag": db.session.query(Tag) + .filter(Tag.name == "second_tag") + .first(), + "third_tag": db.session.query(Tag).filter(Tag.name == "third_tag").first(), + } + saved_queries_tag_relationship = { + tag.name: db.session.query(SavedQuery.id) + .join(SavedQuery.tags) + .filter(Tag.id == tag.id) + .all() + for tag in tags.values() + } + + # Validate API results for each tag + for tag_name, tag in tags.items(): + expected_saved_queries = saved_queries_tag_relationship[tag_name] + + # Filter by tag ID + filter_params = get_filter_params("saved_query_tag_id", tag.id) + response_by_id = self.get_list("saved_query", filter_params) + self.assertEqual(response_by_id.status_code, 200) + data_by_id = json.loads(response_by_id.data.decode("utf-8")) + + # Filter by tag name + filter_params = get_filter_params("saved_query_tags", tag.name) + response_by_name = self.get_list("saved_query", filter_params) + self.assertEqual(response_by_name.status_code, 200) + data_by_name = json.loads(response_by_name.data.decode("utf-8")) + + # Compare results + self.assertEqual( + data_by_id["count"], + data_by_name["count"], + len(expected_saved_queries), + ) + self.assertEqual( + set(query["id"] for query in data_by_id["result"]), + set(query["id"] for query in data_by_name["result"]), + set(query.id for query in expected_saved_queries), + ) + @pytest.mark.usefixtures("create_saved_queries") def test_get_saved_query_favorite_filter(self): """ diff --git a/tests/unit_tests/tags/filters_test.py b/tests/unit_tests/tags/filters_test.py new file mode 100644 index 0000000000000..fadf6216d3276 --- /dev/null +++ b/tests/unit_tests/tags/filters_test.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from flask_appbuilder import Model +from flask_appbuilder.models.sqla.interface import SQLAInterface +from sqlalchemy.orm.session import Session + +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.models.sql_lab import SavedQuery +from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter + +FILTER_MODELS = [Slice, Dashboard, SavedQuery] +OBJECT_TYPES = { + "dashboards": "dashboard", + "slices": "chart", + "saved_query": "query", +} + + +@pytest.mark.parametrize("model", FILTER_MODELS) +@pytest.mark.parametrize("name", ["my_tag", "test tag", "blaah"]) +def test_base_tag_filter_by_name(session: Session, model: Model, name: str) -> None: + table = model.__tablename__ + engine = session.get_bind() + query = session.query(model) + filter = BaseTagNameFilter("tags", SQLAInterface(model)) + final_query = filter.apply(query, name) + compiled_query = final_query.statement.compile( + engine, + compile_kwargs={"literal_binds": True}, + ) + + # Assert the JOIN clause is correct + assert ( + f"FROM {table} JOIN tagged_object AS tagged_object_1 ON {table}.id " + "= tagged_object_1.object_id AND tagged_object_1.object_type = " + f"'{OBJECT_TYPES.get(table)}' JOIN tag ON tagged_object_1.tag_id = tag.id" + ) in str(compiled_query) + + # Assert the WHERE clause is correct + assert str(compiled_query).endswith( + f"WHERE lower(tag.name) LIKE lower('%{name}%'))" + ) + + +@pytest.mark.parametrize("model", FILTER_MODELS) +@pytest.mark.parametrize("id", [3, 5, 8]) +def test_base_tag_filter_by_id(session: Session, model: Model, id: int) -> None: + table = model.__tablename__ + engine = session.get_bind() + query = session.query(model) + + filter = BaseTagIdFilter("tags", SQLAInterface(model)) + filter.id_based_filter = True + final_query = filter.apply(query, id) + compiled_query = final_query.statement.compile( + engine, + compile_kwargs={"literal_binds": True}, + ) + + # Assert the JOIN clause is correct + assert ( + f"FROM {table} JOIN tagged_object AS tagged_object_1 ON {table}.id " + "= tagged_object_1.object_id AND tagged_object_1.object_type = " + f"'{OBJECT_TYPES.get(table)}' JOIN tag ON tagged_object_1.tag_id = tag.id" + ) in str(compiled_query) + + # Assert the WHERE clause is correct + assert str(compiled_query).endswith(f"WHERE tag.id = {id})")