From e857ce0393a57e9318e9fe07a9684ce45cdc11ce Mon Sep 17 00:00:00 2001 From: petar-qb Date: Fri, 13 Dec 2024 16:51:57 +0100 Subject: [PATCH 1/7] Enable that components can be nested arbitrarily deep inside basic type structures too --- ...ic_fix_model_manager_get_model_children.md | 48 ++++++++++++ vizro-core/examples/scratch_dev/app.py | 76 ++++++++++++++----- .../src/vizro/managers/_model_manager.py | 19 ++--- 3 files changed, 113 insertions(+), 30 deletions(-) create mode 100644 vizro-core/changelog.d/20241213_161527_petar_pejovic_fix_model_manager_get_model_children.md diff --git a/vizro-core/changelog.d/20241213_161527_petar_pejovic_fix_model_manager_get_model_children.md b/vizro-core/changelog.d/20241213_161527_petar_pejovic_fix_model_manager_get_model_children.md new file mode 100644 index 000000000..0d990d011 --- /dev/null +++ b/vizro-core/changelog.d/20241213_161527_petar_pejovic_fix_model_manager_get_model_children.md @@ -0,0 +1,48 @@ + + + + + + + + +### Fixed + +- Enable that custom components can be nested arbitrarily deep inside basic type structures (e.g. lists within lists), and not just specific attributes. ([#929](https://github.com/mckinsey/vizro/pull/929)) + + + diff --git a/vizro-core/examples/scratch_dev/app.py b/vizro-core/examples/scratch_dev/app.py index 8866948f9..10daa1903 100644 --- a/vizro-core/examples/scratch_dev/app.py +++ b/vizro-core/examples/scratch_dev/app.py @@ -1,31 +1,69 @@ """Dev app to try things out.""" -from vizro import Vizro -import vizro.plotly.express as px -import vizro.models as vm -from vizro.tables import dash_ag_grid import pandas as pd +import vizro.models as vm +import vizro.plotly.express as px +from vizro import Vizro +from typing import List, Literal, Tuple +from vizro.models.types import ControlType +from dash import html -df = pd.read_csv("https://raw.githubusercontent.com/plotly/datasets/master/ag-grid/olympic-winners.csv") -columnDefs = [ - {"field": "athlete", "headerName": "The full Name of the athlete"}, - {"field": "age", "headerName": "The number of Years since the athlete was born"}, - {"field": "country", "headerName": "The Country the athlete was born in"}, - {"field": "sport", "headerName": "The Sport the athlete participated in"}, - {"field": "total", "headerName": "The Total number of medals won by the athlete"}, -] -defaultColDef = { - "wrapHeaderText": True, - "autoHeaderHeight": True, -} +class CustomGroup(vm.VizroBaseModel): + """Container to group controls.""" + + type: Literal["custom_group"] = "custom_group" + controls: List[Tuple[str, List[ControlType]]] = [[]] + + def build(self): + return html.Div( + children=[ + html.Div( + children=[html.Br(), html.H5(control_tuple[0]), *[control.build() for control in control_tuple[1]]], + ) + for control_tuple in self.controls + ] + ) + + +vm.Page.add_type("controls", CustomGroup) -# Test app ----------------- page = vm.Page( - title="Page Title", - components=[vm.AgGrid(figure=dash_ag_grid(df, columnDefs=columnDefs, defaultColDef=defaultColDef))], + title="Title", + components=[ + vm.Graph(id="graph_id", figure=px.scatter(px.data.iris(), x="sepal_width", y="sepal_length", color="species")), + ], + controls=[ + CustomGroup( + controls=[ + ( + "Categorical Filters", + [ + vm.Filter(column="species"), + ], + ), + ( + "Numeric Filters", + [ + vm.Filter(column="petal_length"), + vm.Filter(column="sepal_length"), + ], + ), + ], + ), + vm.Parameter( + targets=["graph_id.x"], + selector=vm.RadioItems( + title="Select X Axis", + options=["sepal_width", "sepal_length", "petal_width", "petal_length"], + value="sepal_width", + ), + ), + ], ) + + dashboard = vm.Dashboard(pages=[page]) if __name__ == "__main__": diff --git a/vizro-core/src/vizro/managers/_model_manager.py b/vizro-core/src/vizro/managers/_model_manager.py index fe19a1e11..b605b1e89 100644 --- a/vizro-core/src/vizro/managers/_model_manager.py +++ b/vizro-core/src/vizro/managers/_model_manager.py @@ -84,28 +84,25 @@ def _get_models( def __get_model_children(self, model: Model) -> Generator[Model, None, None]: """Iterates through children of `model`. - Currently looks only through certain fields so might miss some children models. + Currently, this method looks only through certain fields so might miss some children models. """ from vizro.models import VizroBaseModel if isinstance(model, VizroBaseModel): yield model + # We don't handle dicts of models at the moment. See below TO-DOs for how this will all be improved in future. + if isinstance(model, (list, tuple)): + for single_model in model: + yield from self.__get_model_children(single_model) + # TODO: in future this list should not be maintained manually. Instead we should look through all model children - # by looking at model.model_fields. + # by looking at model.model_fields. model_fields = ["components", "tabs", "controls", "actions", "selector"] for model_field in model_fields: if (model_field_value := getattr(model, model_field, None)) is not None: - if isinstance(model_field_value, list): - # For fields like components that are list of models. - for single_model_field_value in model_field_value: - yield from self.__get_model_children(single_model_field_value) - else: - # For fields that have single model like selector. - yield from self.__get_model_children(model_field_value) - # We don't handle dicts of models at the moment. See below TODO for how this will all be improved in - # future. + yield from self.__get_model_children(model_field_value) # TODO: Add navigation, accordions and other page objects. Won't be needed once have made whole model # manager work better recursively and have better ways to navigate the hierarchy. In pydantic v2 this would use From 223cadc55b4ee6e2b194f3d8a1a105cc9dce6996 Mon Sep 17 00:00:00 2001 From: petar-qb Date: Wed, 22 Jan 2025 09:27:07 +0100 Subject: [PATCH 2/7] model_manager unit tests and __get_model_children fix for dict --- .../src/vizro/managers/_model_manager.py | 11 +- vizro-core/tests/unit/vizro/conftest.py | 5 - .../unit/vizro/managers/test_model_manager.py | 208 ++++++++++++++++++ .../vizro/models/_components/test_ag_grid.py | 2 +- 4 files changed, 216 insertions(+), 10 deletions(-) create mode 100644 vizro-core/tests/unit/vizro/managers/test_model_manager.py diff --git a/vizro-core/src/vizro/managers/_model_manager.py b/vizro-core/src/vizro/managers/_model_manager.py index b605b1e89..73228ad52 100644 --- a/vizro-core/src/vizro/managers/_model_manager.py +++ b/vizro-core/src/vizro/managers/_model_manager.py @@ -84,17 +84,20 @@ def _get_models( def __get_model_children(self, model: Model) -> Generator[Model, None, None]: """Iterates through children of `model`. - Currently, this method looks only through certain fields so might miss some children models. + Currently, this method looks only through certain fields (components, tabs, controls, actions, selector) and + their children so might miss some children models. """ from vizro.models import VizroBaseModel if isinstance(model, VizroBaseModel): yield model - - # We don't handle dicts of models at the moment. See below TO-DOs for how this will all be improved in future. - if isinstance(model, (list, tuple)): + elif isinstance(model, list): for single_model in model: yield from self.__get_model_children(single_model) + elif isinstance(model, dict): + # We don't look through keys because Vizro models aren't hashable. + for single_model in model.values(): + yield from self.__get_model_children(single_model) # TODO: in future this list should not be maintained manually. Instead we should look through all model children # by looking at model.model_fields. diff --git a/vizro-core/tests/unit/vizro/conftest.py b/vizro-core/tests/unit/vizro/conftest.py index dd15c9fa4..511647e16 100644 --- a/vizro-core/tests/unit/vizro/conftest.py +++ b/vizro-core/tests/unit/vizro/conftest.py @@ -75,11 +75,6 @@ def ag_grid_with_id(gapminder): return dash_ag_grid(id="underlying_ag_grid_id", data_frame=gapminder) -@pytest.fixture -def ag_grid_with_id_and_conf(gapminder): - return dash_ag_grid(id="underlying_ag_grid_id", data_frame=gapminder) - - @pytest.fixture def standard_dash_table(gapminder): return dash_data_table(data_frame=gapminder) diff --git a/vizro-core/tests/unit/vizro/managers/test_model_manager.py b/vizro-core/tests/unit/vizro/managers/test_model_manager.py new file mode 100644 index 000000000..589de5537 --- /dev/null +++ b/vizro-core/tests/unit/vizro/managers/test_model_manager.py @@ -0,0 +1,208 @@ +from typing import Any + +import pytest + +import vizro.models as vm +from vizro.managers import model_manager +from vizro.managers._model_manager import FIGURE_MODELS + + +@pytest.fixture +def page_1(standard_px_chart): + return vm.Page( + id="page_1_id", + title="Page 1", + components=[vm.Button(id="page_1_button_id"), vm.Graph(id="page_1_graph_id", figure=standard_px_chart)], + controls=[], + ) + + +@pytest.fixture +def managers_dashboard_two_pages(vizro_app, page_1, standard_kpi_card): + return vm.Dashboard( + pages=[ + page_1, + vm.Page( + id="page_2_id", + title="Page 2", + components=[ + vm.Button(id="page_2_button_id"), + vm.Figure(id="page_2_figure_id", figure=standard_kpi_card), + ], + controls=[vm.Filter(id="page_2_filter", column="year")], + ), + ] + ) + + +pytestmark = pytest.mark.usefixtures("managers_dashboard_two_pages") + + +class TestGetModels: + """Test _get_models method.""" + + def test_model_type_none_page_none(self): + """model_type is None | page is None -> return all elements.""" + result = [model.id for model in model_manager._get_models()] + + expected_present = { + "page_1_id", + "page_1_button_id", + "page_1_graph_id", + "page_2_id", + "page_2_button_id", + "page_2_figure_id", + } + + assert expected_present.issubset(result) + + def test_model_type_page_none(self): + """model_type is vm.Button | page is None -> return all vm.Button from the dashboard.""" + result = [model.id for model in model_manager._get_models(model_type=vm.Button)] + + expected_present = {"page_1_button_id", "page_2_button_id"} + expected_absent = {"page_1_id", "page_1_graph_id", "page_2_id", "page_2_figure_id"} + + assert expected_present.issubset(result) + assert expected_absent.isdisjoint(result) + + def test_model_type_none_page_not_none(self, page_1): + """model_type is None | page is page_1 -> return all elements from the page_1.""" + result = [model.id for model in model_manager._get_models(page=page_1)] + + expected_present = {"page_1_id", "page_1_button_id", "page_1_graph_id"} + expected_absent = {"page_2_id", "page_2_button_id", "page_2_figure_id"} + + assert expected_present.issubset(result) + assert expected_absent.isdisjoint(result) + + def test_model_type_not_none_page_not_none(self, page_1): + """model_type is vm.Button | page is page_1 -> return all vm.Button from the page_1.""" + result = [model.id for model in model_manager._get_models(model_type=vm.Button, page=page_1)] + + expected_present = {"page_1_button_id"} + expected_absent = {"page_1_id", "page_1_graph_id", "page_2_id", "page_2_button_id", "page_2_figure_id"} + + assert expected_present.issubset(result) + assert expected_absent.isdisjoint(result) + + def test_model_type_no_match_page_none(self): + """model_type matches no type | page is None -> return empty list.""" + # There is no AgGrid in the dashboard + result = [model.id for model in model_manager._get_models(model_type=vm.AgGrid)] + + assert result == [] + + def test_model_type_no_match_page_not_none(self, page_1): + """model_type matches no type | page is page_1 -> return empty list.""" + # There is no AgGrid in the page_1 + result = [model.id for model in model_manager._get_models(model_type=vm.AgGrid, page=page_1)] + + assert result == [] + + def test_model_type_tuple_of_models(self): + """model_type is tuple of models -> return all elements of the specified types from the dashboard.""" + result = [model.id for model in model_manager._get_models(model_type=(vm.Button, vm.Graph))] + + expected_present = {"page_1_button_id", "page_1_graph_id", "page_2_button_id"} + expected_absent = {"page_1_id", "page_2_id", "page_2_figure_id"} + + assert expected_present.issubset(result) + assert expected_absent.isdisjoint(result) + + def test_model_type_figure_models(self): + """model_type is FIGURE_MODELS | page is None -> return all figure elements from the dashboard.""" + result = [model.id for model in model_manager._get_models(model_type=FIGURE_MODELS)] + + expected_present = {"page_1_graph_id", "page_2_figure_id"} + expected_absent = {"page_1_id", "page_1_button_id", "page_2_id", "page_2_button_id"} + + assert expected_present.issubset(result) + assert expected_absent.isdisjoint(result) + + def test_subclass_model_type(self, page_1, standard_px_chart): + """model_type is subclass of vm.Graph -> return all elements of the specified type and its subclasses.""" + + class CustomGraph(vm.Graph): + pass + + page_1.components.append(CustomGraph(id="page_1_custom_graph_id", figure=standard_px_chart)) + + # Return CustomGraph and don't return Graph + custom_graph_result = [model.id for model in model_manager._get_models(model_type=CustomGraph)] + assert "page_1_custom_graph_id" in custom_graph_result + assert "page_1_graph_id" not in custom_graph_result + + # Return CustomGraph and Graph + vm_graph_result = [model.id for model in model_manager._get_models(model_type=vm.Graph)] + assert "page_1_custom_graph_id" in vm_graph_result + assert "page_1_graph_id" in vm_graph_result + + # This test checks if the model manager can find model_type under a nested model. + @pytest.mark.parametrize( + "controls, expected_id", + [ + # model as a property of a custom model + (vm.Filter(id="page_1_control_0", column="year"), "page_1_control_0"), + # model inside a list + ([vm.Filter(id="page_1_control_1", column="year")], "page_1_control_1"), + # model as a value in a dictionary + ({"key_1": vm.Filter(id="page_1_control_2", column="year")}, "page_1_control_2"), + # model nested within a list of lists + ([[vm.Filter(id="page_1_control_3", column="year")]], "page_1_control_3"), + # model nested in a dictionary of dictionaries + ({"key_1": {"key_2": vm.Filter(id="page_1_control_4", column="year")}}, "page_1_control_4"), + # model nested in a list of dictionaries + ([{"key_1": vm.Filter(id="page_1_control_5", column="year")}], "page_1_control_5"), + # model nested in a dictionary of lists + ({"key_1": [vm.Filter(id="page_1_control_6", column="year")]}, "page_1_control_6"), + ], + ) + def test_nested_models(self, page_1, controls, expected_id): + """Model is nested under another model and known property in different ways -> return the model.""" + + class ControlGroup(vm.VizroBaseModel): + controls: Any + + page_1.controls.append(ControlGroup(controls=controls)) + + result = [model.id for model in model_manager._get_models(model_type=vm.Filter, page=page_1)] + + assert expected_id in result + + def test_model_under_unknown_field(self, page_1): + """Model is nested under another model but under an unknown field -> don't return the model.""" + + class ControlGroup(vm.VizroBaseModel): + unknown_field: Any + + page_1.controls.append(ControlGroup(unknown_field=[vm.Filter(id="page_1_control_1", column="year")])) + + result = [model.id for model in model_manager._get_models(model_type=vm.Filter, page=page_1)] + + assert "page_1_control_1" not in result + + +class TestGetModelPage: + """Test _get_model_page method.""" + + def test_model_in_page(self, page_1): + """Model is in page -> return page.""" + result = model_manager._get_model_page(page_1.components[0]) + + assert result == page_1 + + def test_model_not_in_page(self, page_1): + """Model is not in page -> return None.""" + # Instantiate standalone model + button = vm.Button(id="standalone_button_id") + + result = model_manager._get_model_page(button) + + assert result is None + + def test_model_is_page(self, page_1): + """Model is Page -> return that page.""" + result = model_manager._get_model_page(page_1) + + assert result == page_1 diff --git a/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py b/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py index 4277c585d..213b76006 100644 --- a/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py +++ b/vizro-core/tests/unit/vizro/models/_components/test_ag_grid.py @@ -176,7 +176,7 @@ def test_ag_grid_build_mandatory_only(self, standard_ag_grid, gapminder): @pytest.mark.parametrize( "ag_grid, underlying_id_expected", [ - ("ag_grid_with_id_and_conf", "underlying_ag_grid_id"), + ("ag_grid_with_id", "underlying_ag_grid_id"), ("standard_ag_grid", "__input_text_ag_grid"), ], ) From 7f22b54fb1bf9cc9cb135ddeb5461a6709de3485 Mon Sep 17 00:00:00 2001 From: petar-qb Date: Wed, 22 Jan 2025 10:29:36 +0100 Subject: [PATCH 3/7] changelog update --- ...161527_petar_pejovic_fix_model_manager_get_model_children.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vizro-core/changelog.d/20241213_161527_petar_pejovic_fix_model_manager_get_model_children.md b/vizro-core/changelog.d/20241213_161527_petar_pejovic_fix_model_manager_get_model_children.md index 9f31a0074..9e9215ab4 100644 --- a/vizro-core/changelog.d/20241213_161527_petar_pejovic_fix_model_manager_get_model_children.md +++ b/vizro-core/changelog.d/20241213_161527_petar_pejovic_fix_model_manager_get_model_children.md @@ -37,7 +37,7 @@ Uncomment the section that is right (remove the HTML comment wrapper). ### Fixed -- Enable arbitrary deep nesting of custom components. All models can now be nested arbitrarily deep within lists. ([#929](https://github.com/mckinsey/vizro/pull/929)) +- Enable arbitrary deep nesting of custom components. All models can now be nested arbitrarily deep within lists or dicts. ([#929](https://github.com/mckinsey/vizro/pull/929))