diff --git a/vizro-core/changelog.d/20241128_114150_antony.milne_model_manager.md b/vizro-core/changelog.d/20241128_114150_antony.milne_model_manager.md new file mode 100644 index 000000000..c814e29e5 --- /dev/null +++ b/vizro-core/changelog.d/20241128_114150_antony.milne_model_manager.md @@ -0,0 +1,47 @@ + + + + + + +### Changed + +- Custom controls can be nested arbitrarily deep inside `Page.controls`. ([#903](https://github.com/mckinsey/vizro/pull/903)) + + + + diff --git a/vizro-core/examples/scratch_dev/app.py b/vizro-core/examples/scratch_dev/app.py index 66b23823d..58be02e00 100644 --- a/vizro-core/examples/scratch_dev/app.py +++ b/vizro-core/examples/scratch_dev/app.py @@ -1,248 +1,69 @@ -"""Dev app to try things out.""" +from typing import List, Literal -import time -import yaml - -import dash -import pandas as pd -from flask_caching import Cache +from dash import html import vizro.models as vm import vizro.plotly.express as px from vizro import Vizro -from vizro.managers import data_manager -from functools import partial - -print("INITIALIZING") +from vizro.models.types import ControlType -SPECIES_COLORS = {"setosa": "#00b4ff", "versicolor": "#ff9222", "virginica": "#3949ab"} -BAR_CHART_CONF = dict(x="species", color="species", color_discrete_map=SPECIES_COLORS) -SCATTER_CHART_CONF = dict(x="sepal_length", y="petal_length", color="species", color_discrete_map=SPECIES_COLORS) +df_gapminder = px.data.gapminder() -def load_from_file(filter_column=None, parametrized_species=None): - # Load the full iris dataset - df = px.data.iris() - df["date_column"] = pd.date_range(start=pd.to_datetime("2024-01-01"), periods=len(df), freq="D") +class ControlGroup(vm.VizroBaseModel): + """Container to group controls.""" - with open("data.yaml", "r") as file: - data = { - "setosa": 0, - "versicolor": 0, - "virginica": 0, - "min": 0, - "max": 10, - "date_min": "2024-01-01", - "date_max": "2024-05-29", - } - data.update(yaml.safe_load(file) or {}) + type: Literal["control_group"] = "control_group" + title: str + controls: List[ControlType] = [] - if filter_column == "species": - df = pd.concat( - objs=[ - df[df[filter_column] == "setosa"].head(data["setosa"]), - df[df[filter_column] == "versicolor"].head(data["versicolor"]), - df[df[filter_column] == "virginica"].head(data["virginica"]), - ], - ignore_index=True, + def build(self): + return html.Div( + [html.H4(self.title), html.Hr()] + [control.build() for control in self.controls], ) - elif filter_column == "sepal_length": - df = df[df[filter_column].between(data["min"], data["max"], inclusive="both")] - elif filter_column == "date_column": - date_min = pd.to_datetime(data["date_min"]) - date_max = pd.to_datetime(data["date_max"]) - df = df[df[filter_column].between(date_min, date_max, inclusive="both")] - else: - raise ValueError("Invalid filter_column") - - if parametrized_species: - df = df[df["species"].isin(parametrized_species)] - - return df - - -data_manager["load_from_file_species"] = partial(load_from_file, filter_column="species") -data_manager["load_from_file_sepal_length"] = partial(load_from_file, filter_column="sepal_length") -data_manager["load_from_file_date_column"] = partial(load_from_file, filter_column="date_column") - - -# TODO-DEV: Turn on/off caching to see how it affects the app. -# data_manager.cache = Cache(config={"CACHE_TYPE": "SimpleCache", "CACHE_DEFAULT_TIMEOUT": 10}) - - -homepage = vm.Page( - title="Homepage", - components=[ - vm.Card(text="This is the homepage."), - ], -) - -page_1 = vm.Page( - title="Dynamic vs Static filter", - components=[ - vm.Graph( - id="p1-G-1", - figure=px.bar(data_frame="load_from_file_species", **BAR_CHART_CONF), - ), - vm.Graph( - id="p1-G-2", - figure=px.scatter(data_frame=px.data.iris(), **SCATTER_CHART_CONF), - ), - ], - controls=[ - vm.Filter(id="p1-F-1", column="species", targets=["p1-G-1"], selector=vm.Dropdown(title="Dynamic filter")), - vm.Filter(id="p1-F-2", column="species", targets=["p1-G-2"], selector=vm.Dropdown(title="Static filter")), - vm.Parameter( - targets=["p1-G-1.x", "p1-G-2.x"], - selector=vm.RadioItems(options=["species", "sepal_width"], title="Simple X-axis parameter"), - ), - ], -) - - -page_2 = vm.Page( - title="Categorical dynamic selectors", - components=[ - vm.Graph( - id="p2-G-1", - figure=px.bar(data_frame="load_from_file_species", **BAR_CHART_CONF), - ), - ], - controls=[ - vm.Filter(id="p2-F-1", column="species", selector=vm.Dropdown()), - vm.Filter(id="p2-F-2", column="species", selector=vm.Dropdown(multi=False)), - vm.Filter(id="p2-F-3", column="species", selector=vm.Checklist()), - vm.Filter(id="p2-F-4", column="species", selector=vm.RadioItems()), - vm.Parameter( - targets=["p2-G-1.x"], - selector=vm.RadioItems( - options=["species", "sepal_width"], value="species", title="Simple X-axis parameter" - ), - ), - ], -) -page_3 = vm.Page( - title="Numerical dynamic selectors", - components=[ - vm.Graph( - id="p3-G-1", - figure=px.bar(data_frame="load_from_file_sepal_length", **BAR_CHART_CONF), - ), - ], - controls=[ - vm.Filter(id="p3-F-1", column="sepal_length", selector=vm.Slider()), - vm.Filter(id="p3-F-2", column="sepal_length", selector=vm.RangeSlider()), - vm.Parameter( - targets=["p3-G-1.x"], - selector=vm.RadioItems( - options=["species", "sepal_width"], value="species", title="Simple X-axis parameter" - ), - ), - ], -) - -page_4 = vm.Page( - title="[TO BE DONE IN THE FOLLOW UP PR] Temporal dynamic selectors", - components=[ - vm.Graph( - id="p4-G-1", - figure=px.bar(data_frame="load_from_file_date_column", **BAR_CHART_CONF), - ), - ], - controls=[ - vm.Filter(id="p4-F-1", column="date_column", selector=vm.DatePicker(range=False)), - vm.Filter(id="p4-F-2", column="date_column", selector=vm.DatePicker()), - vm.Parameter( - targets=["p4-G-1.x"], - selector=vm.RadioItems( - options=["species", "sepal_width"], value="species", title="Simple X-axis parameter" - ), - ), - ], -) +vm.Page.add_type("controls", ControlGroup) -page_5 = vm.Page( - title="Parametrised dynamic selectors", +page1 = vm.Page( + title="Relationship Analysis", components=[ - vm.Graph( - id="p5-G-1", - figure=px.bar(data_frame="load_from_file_species", **BAR_CHART_CONF), - ), + vm.Graph(id="scatter", figure=px.scatter(df_gapminder, x="gdpPercap", y="lifeExp", size="pop")), ], controls=[ - vm.Filter(id="p5-F-1", column="species", targets=["p5-G-1"], selector=vm.Checklist()), - vm.Parameter( - targets=[ - "p5-G-1.data_frame.parametrized_species", - # TODO: Uncomment the following target and see the magic :D - # Is this the indicator that parameter.targets prop has to support 'target' definition without the '.'? - # "p5-F-1.", + ControlGroup( + title="Group A", + controls=[ + vm.Parameter( + id="this", + targets=["scatter.x"], + selector=vm.Dropdown( + options=["lifeExp", "gdpPercap", "pop"], multi=False, value="gdpPercap", title="Choose x-axis" + ), + ), + vm.Parameter( + targets=["scatter.y"], + selector=vm.Dropdown( + options=["lifeExp", "gdpPercap", "pop"], multi=False, value="lifeExp", title="Choose y-axis" + ), + ), ], - selector=vm.Dropdown( - options=["setosa", "versicolor", "virginica"], multi=True, title="Parametrized species" - ), ), - vm.Parameter( - targets=[ - "p5-G-1.x", - # TODO: Uncomment the following target and see the magic :D - # "p5-F-1.", + ControlGroup( + title="Group B", + controls=[ + vm.Parameter( + targets=["scatter.size"], + selector=vm.Dropdown( + options=["lifeExp", "gdpPercap", "pop"], multi=False, value="pop", title="Choose bubble size" + ), + ) ], - selector=vm.RadioItems( - options=["species", "sepal_width"], value="species", title="Simple X-axis parameter" - ), ), ], ) - -page_6 = vm.Page( - title="Page to test things out", - components=[ - vm.Graph(id="graph_dynamic", figure=px.bar(data_frame="load_from_file_species", **BAR_CHART_CONF)), - vm.Graph( - id="graph_static", - figure=px.scatter(data_frame=px.data.iris(), **SCATTER_CHART_CONF), - ), - ], - controls=[ - vm.Filter( - id="filter_container_id", - column="species", - targets=["graph_dynamic"], - # targets=["graph_static"], - # selector=vm.Dropdown(id="filter_id"), - # selector=vm.Dropdown(id="filter_id", value=["setosa"]), - # selector=vm.Checklist(id="filter_id"), - # selector=vm.Checklist(id="filter_id", value=["setosa"]), - # TODO-BUG: vm.Dropdown(multi=False) Doesn't work if value is cleared. The persistence storage become - # "null" and our placeholder component dmc.DateRangePicker can't process null value. It expects a value or - # a list of values. - # SOLUTION -> Create the "Universal Vizro placeholder component". - # TEMPORARY SOLUTION -> set clearable=False for the dynamic Dropdown(multi=False) - # selector=vm.Dropdown(id="filter_id", multi=False), - # selector=vm.Dropdown(id="filter_id", multi=False, value="setosa"), - # selector=vm.RadioItems(id="filter_id"), - # selector=vm.RadioItems(id="filter_id", value="setosa"), - # selector=vm.Slider(id="filter_id"), - # selector=vm.Slider(id="filter_id", value=5), - # selector=vm.RangeSlider(id="filter_id"), - # selector=vm.RangeSlider(id="filter_id", value=[5, 7]), - ), - vm.Parameter( - targets=["graph_dynamic.x"], - selector=vm.RadioItems(options=["species", "sepal_width"], title="Simple X-axis parameter"), - ), - ], -) - -dashboard = vm.Dashboard(pages=[homepage, page_1, page_2, page_3, page_4, page_5, page_6]) +dashboard = vm.Dashboard(pages=[page1]) if __name__ == "__main__": - app = Vizro().build(dashboard) - - print("RUNNING\n") - - app.run(dev_tools_hot_reload=False) + Vizro().build(dashboard).run() diff --git a/vizro-core/src/vizro/_vizro.py b/vizro-core/src/vizro/_vizro.py index 81c9a46b7..f648d053f 100644 --- a/vizro-core/src/vizro/_vizro.py +++ b/vizro-core/src/vizro/_vizro.py @@ -5,7 +5,7 @@ from collections.abc import Iterable from contextlib import suppress from pathlib import Path, PurePosixPath -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING, TypedDict, cast import dash import plotly.io as pio @@ -145,13 +145,12 @@ def _pre_build(): # Any models that are created during the pre-build process *will not* themselves have pre_build run on them. # In future may add a second pre_build loop after the first one. - # model_manager results is wrapped into a list to avoid RuntimeError: dictionary changed size during iteration - for _, filter_obj in list(model_manager._items_with_type(Filter)): + for filter in cast(Iterable[Filter], model_manager._get_models(Filter)): # Run pre_build on all filters first, then on all other models. This handles dependency between Filter # and Page pre_build and ensures that filters are pre-built before the Page objects that use them. # This is important because the Page pre_build method checks whether filters are dynamic or not, which is # defined in the filter's pre_build method. - filter_obj.pre_build() + filter.pre_build() for model_id in set(model_manager): model = model_manager[model_id] if hasattr(model, "pre_build") and not isinstance(model, Filter): diff --git a/vizro-core/src/vizro/actions/_action_loop/_action_loop.py b/vizro-core/src/vizro/actions/_action_loop/_action_loop.py index d9caac0ff..2d8edefa9 100644 --- a/vizro-core/src/vizro/actions/_action_loop/_action_loop.py +++ b/vizro-core/src/vizro/actions/_action_loop/_action_loop.py @@ -1,10 +1,14 @@ """The action loop creates all the required action callbacks and its components.""" +from collections.abc import Iterable +from typing import cast + from dash import html -from vizro.actions._action_loop._action_loop_utils import _get_actions_on_registered_pages from vizro.actions._action_loop._build_action_loop_callbacks import _build_action_loop_callbacks from vizro.actions._action_loop._get_action_loop_components import _get_action_loop_components +from vizro.managers import model_manager +from vizro.models import Action class ActionLoop: @@ -37,5 +41,8 @@ def _build_actions_models(): List of required components for each `Action` in the `Dashboard` e.g. list[dcc.Download] """ - actions = _get_actions_on_registered_pages() - return html.Div([action.build() for action in actions], id="app_action_models_components_div", hidden=True) + return html.Div( + [action.build() for action in cast(Iterable[Action], model_manager._get_models(Action))], + id="app_action_models_components_div", + hidden=True, + ) diff --git a/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py b/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py deleted file mode 100644 index d640f24bb..000000000 --- a/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Contains utilities to extract the Action and ActionsChain models from registered pages only.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import dash - -from vizro.managers import model_manager -from vizro.managers._model_manager import ModelID - -if TYPE_CHECKING: - from vizro.models import Action, Page - from vizro.models._action._actions_chain import ActionsChain - - -def _get_actions_chains_on_all_pages() -> list[ActionsChain]: - """Gets list of ActionsChain models for registered pages.""" - actions_chains: list[ActionsChain] = [] - # TODO: once dash.page_registry matches up with model_manager, change this to use purely model_manager. - # Making the change now leads to problems since there can be Action models defined that aren't used in the - # dashboard. - # See https://github.com/mckinsey/vizro/pull/366. - for registered_page in dash.page_registry.values(): - try: - page: Page = model_manager[registered_page["module"]] - except KeyError: - continue - actions_chains.extend(model_manager._get_page_actions_chains(page_id=ModelID(str(page.id)))) - return actions_chains - - -def _get_actions_on_registered_pages() -> list[Action]: - """Gets list of Action models for registered pages.""" - return [action for action_chain in _get_actions_chains_on_all_pages() for action in action_chain.actions] diff --git a/vizro-core/src/vizro/actions/_action_loop/_build_action_loop_callbacks.py b/vizro-core/src/vizro/actions/_action_loop/_build_action_loop_callbacks.py index bc76a146d..439c443e1 100644 --- a/vizro-core/src/vizro/actions/_action_loop/_build_action_loop_callbacks.py +++ b/vizro-core/src/vizro/actions/_action_loop/_build_action_loop_callbacks.py @@ -4,20 +4,20 @@ from dash import ClientsideFunction, Input, Output, State, clientside_callback -from vizro.actions._action_loop._action_loop_utils import ( - _get_actions_chains_on_all_pages, - _get_actions_on_registered_pages, -) from vizro.managers import model_manager from vizro.managers._model_manager import ModelID +from vizro.models import Action +from vizro.models._action._actions_chain import ActionsChain logger = logging.getLogger(__name__) def _build_action_loop_callbacks() -> None: """Creates all required dash callbacks for the action loop.""" - actions_chains = _get_actions_chains_on_all_pages() - actions = _get_actions_on_registered_pages() + # actions_chain and actions are not iterated over multiple times so conversion to list is not technically needed, + # but it prevents future bugs and matches _get_action_loop_components. + actions_chains: list[ActionsChain] = list(model_manager._get_models(ActionsChain)) + actions: list[Action] = list(model_manager._get_models(Action)) if not actions_chains: return diff --git a/vizro-core/src/vizro/actions/_action_loop/_get_action_loop_components.py b/vizro-core/src/vizro/actions/_action_loop/_get_action_loop_components.py index 7d34c2a4a..2d18c18df 100644 --- a/vizro-core/src/vizro/actions/_action_loop/_get_action_loop_components.py +++ b/vizro-core/src/vizro/actions/_action_loop/_get_action_loop_components.py @@ -2,10 +2,9 @@ from dash import dcc, html -from vizro.actions._action_loop._action_loop_utils import ( - _get_actions_chains_on_all_pages, - _get_actions_on_registered_pages, -) +from vizro.managers import model_manager +from vizro.models import Action +from vizro.models._action._actions_chain import ActionsChain def _get_action_loop_components() -> html.Div: @@ -15,8 +14,9 @@ def _get_action_loop_components() -> html.Div: List of dcc or html components. """ - actions_chains = _get_actions_chains_on_all_pages() - actions = _get_actions_on_registered_pages() + # actions_chain and actions are iterated over multiple times so must be realized into a list. + actions_chains: list[ActionsChain] = list(model_manager._get_models(ActionsChain)) + actions: list[Action] = list(model_manager._get_models(Action)) if not actions_chains: return html.Div(id="action_loop_components_div") diff --git a/vizro-core/src/vizro/actions/_actions_utils.py b/vizro-core/src/vizro/actions/_actions_utils.py index d46803bd3..1873b6620 100644 --- a/vizro-core/src/vizro/actions/_actions_utils.py +++ b/vizro-core/src/vizro/actions/_actions_utils.py @@ -2,8 +2,9 @@ from __future__ import annotations +from collections.abc import Iterable from copy import deepcopy -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast import pandas as pd @@ -80,15 +81,12 @@ def _apply_filter_controls( return data_frame -def _get_parent_vizro_model(_underlying_callable_object_id: str) -> VizroBaseModel: +def _get_parent_model(_underlying_callable_object_id: str) -> VizroBaseModel: from vizro.models import VizroBaseModel - for _, vizro_base_model in model_manager._items_with_type(VizroBaseModel): - if ( - hasattr(vizro_base_model, "_input_component_id") - and vizro_base_model._input_component_id == _underlying_callable_object_id - ): - return vizro_base_model + for model in cast(Iterable[VizroBaseModel], model_manager._get_models()): + if hasattr(model, "_input_component_id") and model._input_component_id == _underlying_callable_object_id: + return model raise KeyError( f"No parent Vizro model found for underlying callable object with id: {_underlying_callable_object_id}." ) diff --git a/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py b/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py index 4bf82991c..ba18e4fa5 100644 --- a/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py +++ b/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py @@ -1,13 +1,15 @@ """Contains utilities to create the action_callback_mapping.""" -from typing import Any, Callable, Union +from collections.abc import Iterable +from typing import Any, Callable, Union, cast from dash import Output, State, dcc from vizro.actions import _parameter, filter_interaction from vizro.managers import model_manager -from vizro.managers._model_manager import ModelID -from vizro.models import Action, Page +from vizro.managers._model_manager import FIGURE_MODELS, ModelID +from vizro.models import Action, Page, VizroBaseModel +from vizro.models._action._actions_chain import ActionsChain from vizro.models._controls import Filter, Parameter from vizro.models.types import ControlType @@ -15,13 +17,11 @@ # This function can also be reused for all other inputs (filters, parameters). # Potentially this could be a way to reconcile predefined with custom actions, # and make that predefined actions see and add into account custom actions. -def _get_matching_actions_by_function( - page_id: ModelID, action_function: Callable[[Any], dict[str, Any]] -) -> list[Action]: +def _get_matching_actions_by_function(page: Page, action_function: Callable[[Any], dict[str, Any]]) -> list[Action]: """Gets list of `Actions` on triggered `Page` that match the provided `action_function`.""" return [ action - for actions_chain in model_manager._get_page_actions_chains(page_id=page_id) + for actions_chain in cast(Iterable[ActionsChain], model_manager._get_models(ActionsChain, page)) for action in actions_chain.actions if action.function._function == action_function ] @@ -32,21 +32,27 @@ def _get_inputs_of_controls(page: Page, control_type: ControlType) -> list[State """Gets list of `States` for selected `control_type` of triggered `Page`.""" return [ State(component_id=control.selector.id, component_property=control.selector._input_property) - for control in page.controls - if isinstance(control, control_type) + for control in cast(Iterable[ControlType], model_manager._get_models(control_type, page)) ] +def _get_action_trigger(action: Action) -> VizroBaseModel: # type: ignore[return] + """Gets the model that triggers the action with "action_id".""" + from vizro.models._action._actions_chain import ActionsChain + + for actions_chain in cast(Iterable[ActionsChain], model_manager._get_models(ActionsChain)): + if action in actions_chain.actions: + return model_manager[ModelID(str(actions_chain.trigger.component_id))] + + def _get_inputs_of_figure_interactions( page: Page, action_function: Callable[[Any], dict[str, Any]] ) -> list[dict[str, State]]: """Gets list of `States` for selected chart interaction `action_function` of triggered `Page`.""" - figure_interactions_on_page = _get_matching_actions_by_function( - page_id=ModelID(str(page.id)), action_function=action_function - ) + figure_interactions_on_page = _get_matching_actions_by_function(page=page, action_function=action_function) inputs = [] for action in figure_interactions_on_page: - triggered_model = model_manager._get_action_trigger(action_id=ModelID(str(action.id))) + triggered_model = _get_action_trigger(action) required_attributes = ["_filter_interaction_input", "_filter_interaction"] for attribute in required_attributes: if not hasattr(triggered_model, attribute): @@ -60,9 +66,9 @@ def _get_inputs_of_figure_interactions( # TODO: Refactor this and util functions once we implement "_get_input_property" method in VizroBaseModel models -def _get_action_callback_inputs(action_id: ModelID) -> dict[str, list[Union[State, dict[str, State]]]]: +def _get_action_callback_inputs(action: Action) -> dict[str, list[Union[State, dict[str, State]]]]: """Creates mapping of pre-defined action names and a list of `States`.""" - page: Page = model_manager[model_manager._get_model_page_id(model_id=action_id)] + page = model_manager._get_model_page(action) action_input_mapping = { "filters": _get_inputs_of_controls(page=page, control_type=Filter), @@ -76,9 +82,9 @@ def _get_action_callback_inputs(action_id: ModelID) -> dict[str, list[Union[Stat # CALLBACK OUTPUTS -------------- -def _get_action_callback_outputs(action_id: ModelID) -> dict[str, Output]: +def _get_action_callback_outputs(action: Action) -> dict[str, Output]: """Creates mapping of target names and their `Output`.""" - action_function = model_manager[action_id].function._function + action_function = action.function._function # The right solution for mypy here is to not e.g. define new attributes on the base but instead to get mypy to # recognize that model_manager[action_id] is of type Action and hence has the function attribute. @@ -86,7 +92,7 @@ def _get_action_callback_outputs(action_id: ModelID) -> dict[str, Output]: # If not then we can do the cast to Action at the point of consumption here to avoid needing mypy ignores. try: - targets = model_manager[action_id].function["targets"] + targets = action.function["targets"] except KeyError: targets = [] @@ -103,23 +109,23 @@ def _get_action_callback_outputs(action_id: ModelID) -> dict[str, Output]: } -def _get_export_data_callback_outputs(action_id: ModelID) -> dict[str, Output]: +def _get_export_data_callback_outputs(action: Action) -> dict[str, Output]: """Gets mapping of relevant output target name and `Outputs` for `export_data` action.""" - action = model_manager[action_id] - try: targets = action.function["targets"] except KeyError: targets = None - if not targets: - targets = model_manager._get_page_model_ids_with_figure( - page_id=model_manager._get_model_page_id(model_id=action_id) + targets = targets or [ + model.id + for model in cast( + Iterable[VizroBaseModel], model_manager._get_models(FIGURE_MODELS, model_manager._get_model_page(action)) ) + ] return { f"download_dataframe_{target}": Output( - component_id={"type": "download_dataframe", "action_id": action_id, "target_id": target}, + component_id={"type": "download_dataframe", "action_id": action.id, "target_id": target}, component_property="data", ) for target in targets @@ -127,21 +133,21 @@ def _get_export_data_callback_outputs(action_id: ModelID) -> dict[str, Output]: # CALLBACK COMPONENTS -------------- -def _get_export_data_callback_components(action_id: ModelID) -> list[dcc.Download]: +def _get_export_data_callback_components(action: Action) -> list[dcc.Download]: """Creates dcc.Downloads for target components of the `export_data` action.""" - action = model_manager[action_id] - try: targets = action.function["targets"] except KeyError: targets = None - if not targets: - targets = model_manager._get_page_model_ids_with_figure( - page_id=model_manager._get_model_page_id(model_id=action_id) + targets = targets or [ + model.id + for model in cast( + Iterable[VizroBaseModel], model_manager._get_models(FIGURE_MODELS, model_manager._get_model_page(action)) ) + ] return [ - dcc.Download(id={"type": "download_dataframe", "action_id": action_id, "target_id": target}) + dcc.Download(id={"type": "download_dataframe", "action_id": action.id, "target_id": target}) for target in targets ] diff --git a/vizro-core/src/vizro/actions/_callback_mapping/_get_action_callback_mapping.py b/vizro-core/src/vizro/actions/_callback_mapping/_get_action_callback_mapping.py index 10841e18b..5b797835c 100644 --- a/vizro-core/src/vizro/actions/_callback_mapping/_get_action_callback_mapping.py +++ b/vizro-core/src/vizro/actions/_callback_mapping/_get_action_callback_mapping.py @@ -15,15 +15,12 @@ from vizro.actions._filter_action import _filter from vizro.actions._on_page_load_action import _on_page_load from vizro.actions._parameter_action import _parameter -from vizro.managers import model_manager -from vizro.managers._model_manager import ModelID +from vizro.models import Action -def _get_action_callback_mapping( - action_id: ModelID, argument: str -) -> Union[list[dcc.Download], dict[str, DashDependency]]: +def _get_action_callback_mapping(action: Action, argument: str) -> Union[list[dcc.Download], dict[str, DashDependency]]: """Creates mapping of action name and required callback input/output.""" - action_function = model_manager[action_id].function._function + action_function = action.function._function action_callback_mapping: dict[str, Any] = { export_data.__wrapped__: { @@ -50,4 +47,4 @@ def _get_action_callback_mapping( } action_call = action_callback_mapping.get(action_function, {}).get(argument) default_value: Union[list[dcc.Download], dict[str, DashDependency]] = [] if argument == "components" else {} - return default_value if not action_call else action_call(action_id=action_id) + return default_value if not action_call else action_call(action=action) diff --git a/vizro-core/src/vizro/managers/_model_manager.py b/vizro-core/src/vizro/managers/_model_manager.py index 14081681a..fe19a1e11 100644 --- a/vizro-core/src/vizro/managers/_model_manager.py +++ b/vizro-core/src/vizro/managers/_model_manager.py @@ -4,14 +4,14 @@ import random import uuid -from collections.abc import Generator -from typing import TYPE_CHECKING, NewType, Optional, TypeVar, cast +from collections.abc import Generator, Iterable +from typing import TYPE_CHECKING, NewType, Optional, TypeVar, Union, cast from vizro.managers._managers_utils import _state_modifier if TYPE_CHECKING: - from vizro.models import VizroBaseModel - from vizro.models._action._actions_chain import ActionsChain + from vizro.models import Page, VizroBaseModel + # As done for Dash components in dash.development.base_component, fixing the random seed is required to make sure that # the randomly generated model ID for the same model matches up across workers when running gunicorn without --preload. @@ -21,6 +21,13 @@ Model = TypeVar("Model", bound="VizroBaseModel") +# Sentinel object for models that are reactive to controls. This can't be done directly by defining +# FIGURE_MODELS = (Graph, ...) due to circular imports. Done as class for mypy. +# https://stackoverflow.com/questions/69239403/type-hinting-parameters-with-a-sentinel-value-as-the-default +class FIGURE_MODELS: + pass + + class DuplicateIDError(ValueError): """Useful for providing a more explicit error message when a model has id set automatically, e.g. Page.""" @@ -50,90 +57,71 @@ def __iter__(self) -> Generator[ModelID, None, None]: Note this yields model IDs rather key/value pairs to match the interface for a dictionary. """ + # TODO: should this yield models rather than model IDs? Should model_manager be more like set with a special + # lookup by model ID or more like dictionary? yield from self.__models - # TODO: Consider adding an option to iterate only through specific page - "in_page_with_id=None" - def _items_with_type(self, model_type: type[Model]) -> Generator[tuple[ModelID, Model], None, None]: - """Iterates through all models of type `model_type` (including subclasses).""" - for model_id in self: - if isinstance(self[model_id], model_type): - yield model_id, cast(Model, self[model_id]) - - # TODO: Consider returning with yield - # TODO: Make collection of model ids (throughout this file) to be set[ModelID]. - def _get_model_children(self, model_id: ModelID, all_model_ids: Optional[list[ModelID]] = None) -> list[ModelID]: - if all_model_ids is None: - all_model_ids = [] - - all_model_ids.append(model_id) - model = self[model_id] - if hasattr(model, "components"): - for child_model in model.components: - self._get_model_children(child_model.id, all_model_ids) - if hasattr(model, "tabs"): - for child_model in model.tabs: - self._get_model_children(child_model.id, all_model_ids) - return all_model_ids - - # TODO: Consider moving this method in the Dashboard model or some other util file - def _get_model_page_id(self, model_id: ModelID) -> ModelID: # type: ignore[return] - """Gets the id of the page containing the model with "model_id".""" + def _get_models( + self, + model_type: Optional[Union[type[Model], tuple[type[Model], ...], type[FIGURE_MODELS]]] = None, + page: Optional[Page] = None, + ) -> Generator[Model, None, None]: + """Iterates through all models of type `model_type` (including subclasses). + + If `model_type` not given then look at all models. If `page` specified then only give models from that page. + """ + import vizro.models as vm + + if model_type is FIGURE_MODELS: + model_type = (vm.Graph, vm.AgGrid, vm.Table, vm.Figure) + models = self.__get_model_children(page) if page is not None else self.__models.values() + + # Convert to list to avoid changing size when looping through at runtime. + for model in list(models): + if model_type is None or isinstance(model, model_type): + yield model + + def __get_model_children(self, model: Model) -> Generator[Model, None, None]: + """Iterates through children of `model`. + + Currently looks only through certain fields so might miss some children models. + """ + from vizro.models import VizroBaseModel + + if isinstance(model, VizroBaseModel): + yield model + + # TODO: in future this list should not be maintained manually. Instead we should look through all model children + # by looking at model.model_fields. + model_fields = ["components", "tabs", "controls", "actions", "selector"] + + for model_field in model_fields: + if (model_field_value := getattr(model, model_field, None)) is not None: + if isinstance(model_field_value, list): + # For fields like components that are list of models. + for single_model_field_value in model_field_value: + yield from self.__get_model_children(single_model_field_value) + else: + # For fields that have single model like selector. + yield from self.__get_model_children(model_field_value) + # We don't handle dicts of models at the moment. See below TODO for how this will all be improved in + # future. + + # TODO: Add navigation, accordions and other page objects. Won't be needed once have made whole model + # manager work better recursively and have better ways to navigate the hierarchy. In pydantic v2 this would use + # model_fields. Maybe we'd also use Page (or sometimes Dashboard) as the central model for navigating the + # hierarchy rather than it being so generic. + + def _get_model_page(self, model: Model) -> Page: # type: ignore[return] + """Gets the page containing `model`.""" from vizro.models import Page - for page_id, page in model_manager._items_with_type(Page): - page_model_ids = [page_id, self._get_model_children(model_id=page_id)] - - for actions_chain in self._get_page_actions_chains(page_id=page_id): - page_model_ids.append(actions_chain.id) - for action in actions_chain.actions: - page_model_ids.append(action.id) # noqa: PERF401 - - for control in page.controls: - page_model_ids.append(control.id) - if hasattr(control, "selector") and control.selector: - page_model_ids.append(control.selector.id) - - # TODO: Add navigation, accordions and other page objects - - if model_id in page_model_ids: - return cast(ModelID, page.id) - - # TODO: Increase the genericity of this method - def _get_page_actions_chains(self, page_id: ModelID) -> list[ActionsChain]: - """Gets all ActionsChains present on the page.""" - page = self[page_id] - page_actions_chains = [] - - for model_id in self._get_model_children(model_id=page_id): - model = self[model_id] - if hasattr(model, "actions"): - page_actions_chains.extend(model.actions) - - for control in page.controls: - if hasattr(control, "actions") and control.actions: - page_actions_chains.extend(control.actions) - if hasattr(control, "selector") and control.selector and hasattr(control.selector, "actions"): - page_actions_chains.extend(control.selector.actions) - - return page_actions_chains - - # TODO: Consider moving this one to the _callback_mapping_utils.py since it's only used there - def _get_action_trigger(self, action_id: ModelID) -> VizroBaseModel: # type: ignore[return] - """Gets the model that triggers the action with "action_id".""" - from vizro.models._action._actions_chain import ActionsChain - - for _, actions_chain in model_manager._items_with_type(ActionsChain): - if action_id in [action.id for action in actions_chain.actions]: - return self[ModelID(str(actions_chain.trigger.component_id))] - - def _get_page_model_ids_with_figure(self, page_id: ModelID) -> list[ModelID]: - """Gets ids of all components from the page that have a 'figure' registered.""" - return [ - model_id - for model_id in self._get_model_children(model_id=page_id) - # Optimally this statement should be: "if isinstance(model, Figure)" - if hasattr(model_manager[model_id], "figure") - ] + if isinstance(model, Page): + return model + + for page in cast(Iterable[Page], self._get_models(Page)): + if model in self.__get_model_children(page): + return page @staticmethod def _generate_id() -> ModelID: diff --git a/vizro-core/src/vizro/models/_action/_action.py b/vizro-core/src/vizro/models/_action/_action.py index 9aac00ca1..0aef7303b 100644 --- a/vizro-core/src/vizro/models/_action/_action.py +++ b/vizro-core/src/vizro/models/_action/_action.py @@ -11,7 +11,6 @@ except ImportError: # pragma: no cov from pydantic import Field, validator -from vizro.managers._model_manager import ModelID from vizro.models import VizroBaseModel from vizro.models._models_utils import _log_call from vizro.models.types import CapturedCallable @@ -79,7 +78,7 @@ def _get_callback_mapping(self): if self.inputs: callback_inputs = [State(*input.split(".")) for input in self.inputs] else: - callback_inputs = _get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="inputs") + callback_inputs = _get_action_callback_mapping(self, argument="inputs") callback_outputs: Union[list[Output], dict[str, Output]] if self.outputs: @@ -91,9 +90,9 @@ def _get_callback_mapping(self): if len(callback_outputs) == 1: callback_outputs = callback_outputs[0] else: - callback_outputs = _get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="outputs") + callback_outputs = _get_action_callback_mapping(self, argument="outputs") - action_components = _get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="components") + action_components = _get_action_callback_mapping(self, argument="components") return callback_inputs, callback_outputs, action_components diff --git a/vizro-core/src/vizro/models/_base.py b/vizro-core/src/vizro/models/_base.py index 40b00deb6..d2f47470e 100644 --- a/vizro-core/src/vizro/models/_base.py +++ b/vizro-core/src/vizro/models/_base.py @@ -127,7 +127,7 @@ def _extract_captured_callable_source() -> set[str]: # Check to see if the captured callable does use a cleaned module string, if yes then # we can assume that the source code can be imported via Vizro, and thus does not need to be defined value.__repr_clean__().startswith(new) - for _, new in REPLACEMENT_STRINGS.items() + for new in REPLACEMENT_STRINGS.values() ): try: source = textwrap.dedent(inspect.getsource(value._function)) diff --git a/vizro-core/src/vizro/models/_components/ag_grid.py b/vizro-core/src/vizro/models/_components/ag_grid.py index 0ef0e26cf..52b1e36a5 100644 --- a/vizro-core/src/vizro/models/_components/ag_grid.py +++ b/vizro-core/src/vizro/models/_components/ag_grid.py @@ -10,7 +10,7 @@ from pydantic import Field, PrivateAttr, validator from dash import ClientsideFunction, Input, Output, clientside_callback -from vizro.actions._actions_utils import CallbackTriggerDict, _get_component_actions, _get_parent_vizro_model +from vizro.actions._actions_utils import CallbackTriggerDict, _get_component_actions, _get_parent_model from vizro.managers import data_manager from vizro.models import Action, VizroBaseModel from vizro.models._action._actions_chain import _action_validator_factory @@ -100,7 +100,7 @@ def _filter_interaction( return data_frame # ctd_active_cell["id"] represents the underlying table id, so we need to fetch its parent Vizro Table actions. - source_table_actions = _get_component_actions(_get_parent_vizro_model(ctd_cellClicked["id"])) + source_table_actions = _get_component_actions(_get_parent_model(ctd_cellClicked["id"])) for action in source_table_actions: if action.function._function.__name__ != "filter_interaction" or target not in action.function["targets"]: diff --git a/vizro-core/src/vizro/models/_components/form/_text_area.py b/vizro-core/src/vizro/models/_components/form/_text_area.py index 5ac25fdb7..bb1ea7fa1 100644 --- a/vizro-core/src/vizro/models/_components/form/_text_area.py +++ b/vizro-core/src/vizro/models/_components/form/_text_area.py @@ -4,9 +4,9 @@ from dash import html try: - from pydantic.v1 import Field + from pydantic.v1 import Field, PrivateAttr except ImportError: # pragma: no cov - from pydantic import Field + from pydantic import Field, PrivateAttr from vizro.models import Action, VizroBaseModel from vizro.models._action._actions_chain import _action_validator_factory @@ -32,6 +32,9 @@ class TextArea(VizroBaseModel): placeholder: str = Field("", description="Default text to display in input field") actions: list[Action] = [] + # Component properties for actions and interactions + _input_property: str = PrivateAttr("value") + # Re-used validators # TODO: Before making public, consider how actions should be triggered and what the default property should be # See comment thread: https://github.com/mckinsey/vizro/pull/298#discussion_r1478137654 diff --git a/vizro-core/src/vizro/models/_components/form/_user_input.py b/vizro-core/src/vizro/models/_components/form/_user_input.py index bb98a14bb..7ca821f65 100644 --- a/vizro-core/src/vizro/models/_components/form/_user_input.py +++ b/vizro-core/src/vizro/models/_components/form/_user_input.py @@ -4,9 +4,9 @@ from dash import html try: - from pydantic.v1 import Field + from pydantic.v1 import Field, PrivateAttr except ImportError: # pragma: no cov - from pydantic import Field + from pydantic import Field, PrivateAttr from vizro.models import Action, VizroBaseModel from vizro.models._action._actions_chain import _action_validator_factory @@ -32,6 +32,9 @@ class UserInput(VizroBaseModel): placeholder: str = Field("", description="Default text to display in input field") actions: list[Action] = [] + # Component properties for actions and interactions + _input_property: str = PrivateAttr("value") + # Re-used validators # TODO: Before making public, consider how actions should be triggered and what the default property should be # See comment thread: https://github.com/mckinsey/vizro/pull/298#discussion_r1478137654 diff --git a/vizro-core/src/vizro/models/_components/table.py b/vizro-core/src/vizro/models/_components/table.py index 8ffb4ea8d..edfea4c5c 100644 --- a/vizro-core/src/vizro/models/_components/table.py +++ b/vizro-core/src/vizro/models/_components/table.py @@ -9,7 +9,7 @@ except ImportError: # pragma: no cov from pydantic import Field, PrivateAttr, validator -from vizro.actions._actions_utils import CallbackTriggerDict, _get_component_actions, _get_parent_vizro_model +from vizro.actions._actions_utils import CallbackTriggerDict, _get_component_actions, _get_parent_model from vizro.managers import data_manager from vizro.models import Action, VizroBaseModel from vizro.models._action._actions_chain import _action_validator_factory @@ -104,7 +104,7 @@ def _filter_interaction( return data_frame # ctd_active_cell["id"] represents the underlying table id, so we need to fetch its parent Vizro Table actions. - source_table_actions = _get_component_actions(_get_parent_vizro_model(ctd_active_cell["id"])) + source_table_actions = _get_component_actions(_get_parent_model(ctd_active_cell["id"])) for action in source_table_actions: if action.function._function.__name__ != "filter_interaction" or target not in action.function["targets"]: diff --git a/vizro-core/src/vizro/models/_controls/filter.py b/vizro-core/src/vizro/models/_controls/filter.py index 61f8ca523..8a18add8e 100644 --- a/vizro-core/src/vizro/models/_controls/filter.py +++ b/vizro-core/src/vizro/models/_controls/filter.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Literal, Union +from collections.abc import Iterable +from typing import Any, Literal, Union, cast import pandas as pd from dash import dcc @@ -17,7 +18,7 @@ from vizro.actions import _filter from vizro.managers import data_manager, model_manager from vizro.managers._data_manager import _DynamicData -from vizro.managers._model_manager import ModelID +from vizro.managers._model_manager import FIGURE_MODELS, ModelID from vizro.models import Action, VizroBaseModel from vizro.models._components.form import ( Checklist, @@ -136,10 +137,12 @@ def pre_build(self): # want to raise an error if the column is not found in a figure's data_frame, it will just be ignored. # This is the case when bool(self.targets) is False. # Possibly in future this will change (which would be breaking change). - proposed_targets = self.targets or model_manager._get_page_model_ids_with_figure( - page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id))) - ) - + proposed_targets = self.targets or [ + cast(ModelID, model.id) + for model in cast( + Iterable[VizroBaseModel], model_manager._get_models(FIGURE_MODELS, model_manager._get_model_page(self)) + ) + ] # TODO: Currently dynamic data functions require a default value for every argument. Even when there is a # dataframe parameter, the default value is used when pre-build the filter e.g. to find the targets, # column type (and hence selector) and initial values. There are three ways to handle this: diff --git a/vizro-core/src/vizro/models/_controls/parameter.py b/vizro-core/src/vizro/models/_controls/parameter.py index 9663c3f14..cdc76936c 100644 --- a/vizro-core/src/vizro/models/_controls/parameter.py +++ b/vizro-core/src/vizro/models/_controls/parameter.py @@ -1,4 +1,5 @@ -from typing import Literal +from collections.abc import Iterable +from typing import Literal, cast try: from pydantic.v1 import Field, validator @@ -61,7 +62,7 @@ def check_data_frame_as_target_argument(cls, target): @validator("targets") def check_duplicate_parameter_target(cls, targets): all_targets = targets.copy() - for _, param in model_manager._items_with_type(Parameter): + for param in cast(Iterable[Parameter], model_manager._get_models(Parameter)): all_targets.extend(param.targets) duplicate_targets = {item for item in all_targets if all_targets.count(item) > 1} if duplicate_targets: diff --git a/vizro-core/src/vizro/models/_navigation/_navigation_utils.py b/vizro-core/src/vizro/models/_navigation/_navigation_utils.py index 7e8e1f15f..387cf8b9a 100644 --- a/vizro-core/src/vizro/models/_navigation/_navigation_utils.py +++ b/vizro-core/src/vizro/models/_navigation/_navigation_utils.py @@ -1,7 +1,8 @@ from __future__ import annotations import itertools -from typing import TypedDict +from collections.abc import Iterable +from typing import TypedDict, cast import dash_bootstrap_components as dbc @@ -15,8 +16,7 @@ def _validate_pages(pages): pages_as_list = list(itertools.chain(*pages.values())) if isinstance(pages, dict) else pages # Ideally we would use dashboard.pages in the model manager here, but we only register pages in # dashboard.pre_build and model manager cannot find a Dashboard at validation time. - # page[0] gives the page model ID. - registered_pages = [page[0] for page in model_manager._items_with_type(Page)] + registered_pages = [page.id for page in cast(Iterable[Page], model_manager._get_models(Page))] if not pages_as_list: raise ValueError("Ensure this value has at least 1 item.") diff --git a/vizro-core/src/vizro/models/_page.py b/vizro-core/src/vizro/models/_page.py index 329b4a279..e137987bb 100644 --- a/vizro-core/src/vizro/models/_page.py +++ b/vizro-core/src/vizro/models/_page.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Mapping -from typing import Any, Optional, TypedDict, Union +from collections.abc import Iterable, Mapping +from typing import Any, Optional, TypedDict, Union, cast from dash import dcc, html @@ -13,8 +13,8 @@ from vizro._constants import ON_PAGE_LOAD_ACTION_PREFIX from vizro.actions import _on_page_load from vizro.managers import model_manager -from vizro.managers._model_manager import DuplicateIDError, ModelID -from vizro.models import Action, Layout, VizroBaseModel +from vizro.managers._model_manager import FIGURE_MODELS, DuplicateIDError +from vizro.models import Action, Filter, Layout, VizroBaseModel from vizro.models._action._actions_chain import ActionsChain, Trigger from vizro.models._layout import set_layout from vizro.models._models_utils import _log_call, check_captured_callable, validate_min_length @@ -96,10 +96,15 @@ def __vizro_exclude_fields__(self) -> Optional[Union[set[str], Mapping[str, Any] @_log_call def pre_build(self): - targets = model_manager._get_page_model_ids_with_figure(page_id=ModelID(str(self.id))) - - # TODO NEXT: make work generically for control group - targets.extend(control.id for control in self.controls if getattr(control, "_dynamic", False)) + figure_targets = [ + model.id for model in cast(Iterable[VizroBaseModel], model_manager._get_models(FIGURE_MODELS, page=self)) + ] + filter_targets = [ + filter.id + for filter in cast(Iterable[Filter], model_manager._get_models(Filter, page=self)) + if filter._dynamic + ] + targets = figure_targets + filter_targets if targets: self.actions = [ diff --git a/vizro-core/tests/unit/vizro/actions/_action_loop/test_get_action_loop_components.py b/vizro-core/tests/unit/vizro/actions/_action_loop/test_get_action_loop_components.py index 9e649188d..55838f186 100644 --- a/vizro-core/tests/unit/vizro/actions/_action_loop/test_get_action_loop_components.py +++ b/vizro-core/tests/unit/vizro/actions/_action_loop/test_get_action_loop_components.py @@ -1,8 +1,9 @@ """Unit tests for vizro.actions._action_loop._get_action_loop_components file.""" import pytest -from asserts import assert_component_equal +from asserts import STRIP_ALL, assert_component_equal from dash import dcc, html +from dash._utils import stringify_id import vizro.models as vm import vizro.plotly.express as px @@ -27,7 +28,7 @@ def gateway_components(request): components = request.param actions_chain_ids = [model_manager[component].actions[0].id for component in components] return [ - dcc.Store(id={"type": "gateway_input", "trigger_id": actions_chain_id}, data=f"{actions_chain_id}") + dcc.Store(id={"type": "gateway_input", "trigger_id": actions_chain_id}, data=actions_chain_id) for actions_chain_id in actions_chain_ids ] @@ -153,11 +154,26 @@ def test_all_action_loop_components( result = _get_action_loop_components() expected = html.Div( id="action_loop_components_div", - children=fundamental_components - + gateway_components - + action_trigger_components - + [action_trigger_actions_id_component] - + [trigger_to_actions_chain_mapper_component], + children=[ + *fundamental_components, + *gateway_components, + *action_trigger_components, + action_trigger_actions_id_component, + trigger_to_actions_chain_mapper_component, + ], ) - assert_component_equal(result, expected) + # Data in these dcc.Stores is arbitrarily order. Sort in advance to ensure that assert_component_equal + # is order-agnostic for their data. + for key in ("action_trigger_actions_id", "trigger_to_actions_chain_mapper"): + result[key].data = sorted(result[key].data) + expected[key].data = sorted(expected[key].data) + + # Validate the action_loop_components_div wrapper. + assert_component_equal(result, expected, keys_to_strip=STRIP_ALL) + + # Order of dcc.Stores inside div wrapper is arbitrary so sort by stringified id to do order-agnostic comparison. + assert_component_equal( + sorted(result.children, key=lambda component: stringify_id(component.id)), + sorted(expected.children, key=lambda component: stringify_id(component.id)), + ) diff --git a/vizro-core/tests/unit/vizro/actions/_callback_mapping/test_get_action_callback_mapping.py b/vizro-core/tests/unit/vizro/actions/_callback_mapping/test_get_action_callback_mapping.py index 27adbb223..c32b8e1c7 100644 --- a/vizro-core/tests/unit/vizro/actions/_callback_mapping/test_get_action_callback_mapping.py +++ b/vizro-core/tests/unit/vizro/actions/_callback_mapping/test_get_action_callback_mapping.py @@ -9,6 +9,7 @@ from vizro import Vizro from vizro.actions import export_data, filter_interaction from vizro.actions._callback_mapping._get_action_callback_mapping import _get_action_callback_mapping +from vizro.managers import model_manager from vizro.models.types import capture @@ -185,7 +186,7 @@ class TestCallbackMapping: ], ) def test_action_callback_mapping_inputs(self, action_id, action_callback_inputs_expected): - result = _get_action_callback_mapping(action_id=action_id, argument="inputs") + result = _get_action_callback_mapping(action=model_manager[action_id], argument="inputs") assert result == action_callback_inputs_expected @pytest.mark.parametrize( @@ -242,14 +243,14 @@ def test_action_callback_mapping_inputs(self, action_id, action_callback_inputs_ indirect=["action_callback_outputs_expected"], ) def test_action_callback_mapping_outputs(self, action_id, action_callback_outputs_expected): - result = _get_action_callback_mapping(action_id=action_id, argument="outputs") + result = _get_action_callback_mapping(action=model_manager[action_id], argument="outputs") assert result == action_callback_outputs_expected @pytest.mark.parametrize( "export_data_outputs_expected", [("scatter_chart", "scatter_chart_2", "vizro_table")], indirect=True ) def test_export_data_no_targets_set_mapping_outputs(self, export_data_outputs_expected): - result = _get_action_callback_mapping(action_id="export_data_action", argument="outputs") + result = _get_action_callback_mapping(action=model_manager["export_data_action"], argument="outputs") assert result == export_data_outputs_expected @@ -266,7 +267,7 @@ def test_export_data_no_targets_set_mapping_outputs(self, export_data_outputs_ex def test_export_data_targets_set_mapping_outputs( self, config_for_testing_all_components_with_actions, export_data_outputs_expected ): - result = _get_action_callback_mapping(action_id="export_data_action", argument="outputs") + result = _get_action_callback_mapping(action=model_manager["export_data_action"], argument="outputs") assert result == export_data_outputs_expected @@ -274,7 +275,9 @@ def test_export_data_targets_set_mapping_outputs( "export_data_components_expected", [("scatter_chart", "scatter_chart_2", "vizro_table")], indirect=True ) def test_export_data_no_targets_set_mapping_components(self, export_data_components_expected): - result_components = _get_action_callback_mapping(action_id="export_data_action", argument="components") + result_components = _get_action_callback_mapping( + action=model_manager["export_data_action"], argument="components" + ) assert_component_equal(result_components, export_data_components_expected) @pytest.mark.parametrize( @@ -290,11 +293,13 @@ def test_export_data_no_targets_set_mapping_components(self, export_data_compone def test_export_data_targets_set_mapping_components( self, config_for_testing_all_components_with_actions, export_data_components_expected ): - result_components = _get_action_callback_mapping(action_id="export_data_action", argument="components") + result_components = _get_action_callback_mapping( + action=model_manager["export_data_action"], argument="components" + ) assert_component_equal(result_components, export_data_components_expected) def test_known_action_unknown_argument(self): - result = _get_action_callback_mapping(action_id="export_data_action", argument="unknown-argument") + result = _get_action_callback_mapping(action=model_manager["export_data_action"], argument="unknown-argument") assert result == {} # "export_data_custom_action" represents a unique scenario within custom actions, where the function's name @@ -305,5 +310,5 @@ def test_known_action_unknown_argument(self): "argument, expected", [("inputs", {}), ("outputs", {}), ("components", []), ("unknown-argument", {})] ) def test_custom_action_mapping(self, action_id, argument, expected): - result = _get_action_callback_mapping(action_id=action_id, argument=argument) + result = _get_action_callback_mapping(action=model_manager[action_id], argument=argument) assert result == expected diff --git a/vizro-core/tests/unit/vizro/models/_controls/test_filter.py b/vizro-core/tests/unit/vizro/models/_controls/test_filter.py index 823b88a91..2b64666e9 100644 --- a/vizro-core/tests/unit/vizro/models/_controls/test_filter.py +++ b/vizro-core/tests/unit/vizro/models/_controls/test_filter.py @@ -841,6 +841,7 @@ class TestFilterBuild: def test_filter_build(self, test_column, test_selector): filter = vm.Filter(column=test_column, selector=test_selector) model_manager["test_page"].controls = [filter] + filter.pre_build() result = filter.build() expected = test_selector.build()