diff --git a/vizro-core/examples/default/app.py b/vizro-core/examples/default/app.py index 96b40af30..ed8c69843 100644 --- a/vizro-core/examples/default/app.py +++ b/vizro-core/examples/default/app.py @@ -8,508 +8,532 @@ df_gapminder = px.data.gapminder() -single_tabs = vm.Page( - title="Single Tabs", - components=[ - vm.Tabs( - id="first_tab", - tabs=[ - vm.Container( - id="tab-1", - title="Tab I", - components=[ - vm.Graph( - id="graph_1", - figure=px.line( - df_gapminder, - title="Graph_1", - x="year", - y="lifeExp", - color="continent", - line_group="country", - hover_name="country", + +def single_tabs(): + single_tabs = vm.Page( + title="Single Tabs", + components=[ + vm.Tabs( + id="first_tab", + tabs=[ + vm.Container( + id="tab-1", + title="Tab I", + components=[ + vm.Graph( + id="graph_1", + figure=px.line( + df_gapminder, + title="Graph_1", + x="year", + y="lifeExp", + color="continent", + line_group="country", + hover_name="country", + ), ), - ), - vm.Graph( - id="graph_2", - figure=px.scatter( - df_gapminder, - title="Graph_2", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + vm.Graph( + id="graph_2", + figure=px.scatter( + df_gapminder, + title="Graph_2", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ), - vm.Graph( - id="graph_3", - figure=px.box( - df_gapminder, - title="Graph_3", - x="continent", - y="lifeExp", - color="continent", + vm.Graph( + id="graph_3", + figure=px.box( + df_gapminder, + title="Graph_3", + x="continent", + y="lifeExp", + color="continent", + ), ), - ), - vm.Button( - text="Export data", - actions=[ - vm.Action(function=export_data()), - ], - ), - ], - ), - vm.Container( - id="tab-2", - title="Tab II", - components=[ - vm.Graph( - id="graph_4", - figure=px.scatter( - df_gapminder, - title="Graph_4", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + vm.Button( + text="Export data", + actions=[ + vm.Action(function=export_data()), + ], ), - ), - vm.Graph( - id="graph_5", - figure=px.box( - df_gapminder, - title="Graph_5", - x="continent", - y="lifeExp", - color="continent", + ], + ), + vm.Container( + id="tab-2", + title="Tab II", + components=[ + vm.Graph( + id="graph_4", + figure=px.scatter( + df_gapminder, + title="Graph_4", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ), - vm.Graph( - id="graph_6", - figure=px.line( - df_gapminder, - title="Graph_6", - x="year", - y="lifeExp", - color="continent", - line_group="country", - hover_name="country", + vm.Graph( + id="graph_5", + figure=px.box( + df_gapminder, + title="Graph_5", + x="continent", + y="lifeExp", + color="continent", + ), + ), + vm.Graph( + id="graph_6", + figure=px.line( + df_gapminder, + title="Graph_6", + x="year", + y="lifeExp", + color="continent", + line_group="country", + hover_name="country", + ), ), + ], + ), + ], + ), + ], + controls=[ + vm.Parameter( + targets=[ + "graph_1.y", + "graph_2.y", + "graph_3.y", + "graph_4.y", + "graph_5.y", + "graph_6.y", + ], + selector=vm.RadioItems(options=["lifeExp", "pop", "gdpPercap"], title="Select variable"), + ), + vm.Filter(column="continent"), + ], + ) + return single_tabs + + +def multiple_containers_custom_layout(): + multiple_containers_custom_layout = vm.Page( + title="Multiple Containers - custom layout", + components=[ + vm.Container( + id="cont_1", + title="Container I", + layout=vm.Layout(grid=[[0, 1]]), + components=[ + vm.Graph( + id="graph_11", + figure=px.line( + df_gapminder, + title="Graph_11", + x="year", + y="lifeExp", + color="continent", + line_group="country", + hover_name="country", ), - ], - ), - ], - ), - ], - controls=[ - vm.Parameter( - targets=[ - "graph_1.y", - "graph_2.y", - "graph_3.y", - "graph_4.y", - "graph_5.y", - "graph_6.y", - ], - selector=vm.RadioItems(options=["lifeExp", "pop", "gdpPercap"], title="Select variable"), - ), - vm.Filter(column="continent"), - ], -) -multiple_containers_custom_layout = vm.Page( - title="Multiple Containers - custom layout", - components=[ - vm.Container( - id="cont_1", - title="Container I", - layout=vm.Layout(grid=[[0, 1]]), - components=[ - vm.Graph( - id="graph_11", - figure=px.line( - df_gapminder, - title="Graph_11", - x="year", - y="lifeExp", - color="continent", - line_group="country", - hover_name="country", ), - ), - vm.Graph( - id="graph_12", - figure=px.scatter( - df_gapminder, - title="Graph_12", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + vm.Graph( + id="graph_12", + figure=px.scatter( + df_gapminder, + title="Graph_12", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ), - ], - ), - vm.Container( - id="cont_2", - title="Container II", - layout=vm.Layout(grid=[[0, 1]], row_min_height="300px"), - components=[ - vm.Graph( - id="graph_13", - figure=px.box( - df_gapminder, - title="Graph_13", - x="continent", - y="lifeExp", - color="continent", + ], + ), + vm.Container( + id="cont_2", + title="Container II", + layout=vm.Layout(grid=[[0, 1]], row_min_height="300px"), + components=[ + vm.Graph( + id="graph_13", + figure=px.box( + df_gapminder, + title="Graph_13", + x="continent", + y="lifeExp", + color="continent", + ), ), - ), - vm.Button( - text="Export data", - actions=[ - vm.Action(function=export_data()), - ], - ), - ], - ), - ], - controls=[ - vm.Filter(column="continent"), - ], -) -single_tabs_action = vm.Page( - title="Single Tabs - with action", - components=[ - vm.Tabs( - tabs=[ - vm.Container( - id="container_table", - title="Tab I Table", - components=[ - vm.Table( - id="table-1", - figure=dash_data_table( - id="dash_datatable-1", - data_frame=df_gapminder, + vm.Button( + text="Export data", + actions=[ + vm.Action(function=export_data()), + ], + ), + ], + ), + ], + controls=[ + vm.Filter(column="continent"), + ], + ) + return multiple_containers_custom_layout + + +def single_tabs_action(): + single_tabs_action = vm.Page( + title="Single Tabs - with action", + components=[ + vm.Tabs( + tabs=[ + vm.Container( + id="container_table", + title="Tab I Table", + components=[ + vm.Table( + id="table-1", + figure=dash_data_table( + id="dash_datatable-1", + data_frame=df_gapminder, + ), ), - ), - vm.Button( - text="Export data", - actions=[ - vm.Action(function=export_data()), - ], - ), - ], - ) - ] - ), - ], - controls=[ - vm.Filter(column="continent"), - ], -) + vm.Button( + text="Export data", + actions=[ + vm.Action(function=export_data()), + ], + ), + ], + ) + ] + ), + ], + controls=[ + vm.Filter(column="continent"), + ], + ) + return single_tabs_action + -multiple_tabs = vm.Page( - id="page_4", - title="Multiple Tabs", - components=[ - vm.Tabs( - id="page-4-tab1", - tabs=[ - vm.Container( - title="Tab 1 container 1", - components=[ - vm.Graph( - id="graph_44", - figure=px.scatter( - df_gapminder, - title="Graph_44", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", +def multiple_tabs(): + multiple_tabs = vm.Page( + id="page_4", + title="Multiple Tabs", + components=[ + vm.Tabs( + id="page-4-tab1", + tabs=[ + vm.Container( + title="Tab 1 container 1", + components=[ + vm.Graph( + id="graph_44", + figure=px.scatter( + df_gapminder, + title="Graph_44", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ), - ], - ), - vm.Container( - title="Tab 1 container 2", - components=[ - vm.Graph( - id="graph_441", - figure=px.scatter( - df_gapminder, - title="Graph_441", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + ], + ), + vm.Container( + title="Tab 1 container 2", + components=[ + vm.Graph( + id="graph_441", + figure=px.scatter( + df_gapminder, + title="Graph_441", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ), - ], - ), - ], - ), - vm.Tabs( - id="page-4-tab2", - tabs=[ - vm.Container( - title="Tab 2 container", - components=[ - vm.Graph( - id="graph_45", - figure=px.scatter( - df_gapminder, - title="Graph_45", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + ], + ), + ], + ), + vm.Tabs( + id="page-4-tab2", + tabs=[ + vm.Container( + title="Tab 2 container", + components=[ + vm.Graph( + id="graph_45", + figure=px.scatter( + df_gapminder, + title="Graph_45", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ), - ], - ), - vm.Container( - title="Tab 2 container 2", - components=[ - vm.Graph( - id="graph_451", - figure=px.scatter( - df_gapminder, - title="Graph_451", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + ], + ), + vm.Container( + title="Tab 2 container 2", + components=[ + vm.Graph( + id="graph_451", + figure=px.scatter( + df_gapminder, + title="Graph_451", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), + ], + ), + ], + ), + ], + ) + return multiple_tabs + + +def single_tabs_custom_layout(): + single_tabs_custom_layout = vm.Page( + title="Single Tabs - custom layout", + components=[ + vm.Tabs( + id="first-tabr", + tabs=[ + vm.Container( + layout=vm.Layout( + grid=[[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 2, 2], [1, 1, 2, 2], [3, -1, -1, -1]] ), - ], - ), - ], - ), - ], -) -single_tabs_custom_layout = vm.Page( - title="Single Tabs - custom layout", - components=[ - vm.Tabs( - id="first-tabr", - tabs=[ - vm.Container( - layout=vm.Layout(grid=[[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 2, 2], [1, 1, 2, 2], [3, -1, -1, -1]]), - id="tab-1r", - title="Tab I Title", - components=[ - vm.Graph( - id="graph_1r", - figure=px.line( - df_gapminder, - title="Graph_1", - x="year", - y="lifeExp", - color="continent", - line_group="country", - hover_name="country", + id="tab-1r", + title="Tab I Title", + components=[ + vm.Graph( + id="graph_1r", + figure=px.line( + df_gapminder, + title="Graph_1", + x="year", + y="lifeExp", + color="continent", + line_group="country", + hover_name="country", + ), ), - ), - vm.Graph( - id="graph_2r", - figure=px.scatter( - df_gapminder, - title="Graph_2", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + vm.Graph( + id="graph_2r", + figure=px.scatter( + df_gapminder, + title="Graph_2", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ), - vm.Graph( - id="graph_3r", - figure=px.box( - df_gapminder, - title="Graph_3", - x="continent", - y="lifeExp", - color="continent", + vm.Graph( + id="graph_3r", + figure=px.box( + df_gapminder, + title="Graph_3", + x="continent", + y="lifeExp", + color="continent", + ), ), - ), - vm.Button( - text="Export data", - actions=[ - vm.Action(function=export_data()), - ], - ), - ], - ), - vm.Container( - id="tab-2r", - title="Tab II", - layout=vm.Layout( - grid=[ - [0, 0, 0, 0], - [0, 0, 0, 0], - [1, 1, 2, 2], - [1, 1, 2, 2], - ] - ), - components=[ - vm.Graph( - id="graph_4r", - figure=px.scatter( - df_gapminder, - title="Graph_4", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + vm.Button( + text="Export data", + actions=[ + vm.Action(function=export_data()), + ], ), + ], + ), + vm.Container( + id="tab-2r", + title="Tab II", + layout=vm.Layout( + grid=[ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 1, 2, 2], + [1, 1, 2, 2], + ] ), - vm.Graph( - id="graph_5r", - figure=px.box( - df_gapminder, - title="Graph_5", - x="continent", - y="lifeExp", - color="continent", + components=[ + vm.Graph( + id="graph_4r", + figure=px.scatter( + df_gapminder, + title="Graph_4", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ), - vm.Graph( - id="graph_6r", - figure=px.line( - df_gapminder, - title="Graph_6", - x="year", - y="lifeExp", - color="continent", - line_group="country", - hover_name="country", + vm.Graph( + id="graph_5r", + figure=px.box( + df_gapminder, + title="Graph_5", + x="continent", + y="lifeExp", + color="continent", + ), ), - ), - ], - ), - ], - ), - ], -) - -multiple_containers_nested = vm.Page( - id="page_6", - title="Multiple Containers - Nested", - layout=vm.Layout( - grid=[ - [0, 0, 0, 0], - [0, 0, 0, 0], - [1, 1, 1, 1], - [1, 1, 1, 1], - ] - ), - components=[ - vm.Container( - layout=vm.Layout(grid=[[0, 1], [0, 1]]), - components=[ - vm.Container( - layout=vm.Layout( - grid=[ - [0, 0, 1, 1], - [0, 0, 1, 1], - [2, 2, 3, 3], - [2, 2, 3, 3], - ] - ), - components=[ - vm.Graph( - id="graph_1rn", - figure=px.line( - df_gapminder, - title="Graph_1", - x="year", - y="lifeExp", - color="continent", - line_group="country", - hover_name="country", + vm.Graph( + id="graph_6r", + figure=px.line( + df_gapminder, + title="Graph_6", + x="year", + y="lifeExp", + color="continent", + line_group="country", + hover_name="country", + ), ), + ], + ), + ], + ), + ], + ) + return single_tabs_custom_layout + + +def multiple_containers_nested(): + multiple_containers_nested = vm.Page( + id="page_6", + title="Multiple Containers - Nested", + layout=vm.Layout( + grid=[ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + ] + ), + components=[ + vm.Container( + layout=vm.Layout(grid=[[0, 1], [0, 1]]), + components=[ + vm.Container( + layout=vm.Layout( + grid=[ + [0, 0, 1, 1], + [0, 0, 1, 1], + [2, 2, 3, 3], + [2, 2, 3, 3], + ] ), - vm.Graph( - id="graph_2rn", - figure=px.scatter( - df_gapminder, - title="Graph_2", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + components=[ + vm.Graph( + id="graph_1rn", + figure=px.line( + df_gapminder, + title="Graph_1", + x="year", + y="lifeExp", + color="continent", + line_group="country", + hover_name="country", + ), ), - ), - vm.Graph( - id="graph_3rn", - figure=px.box( - df_gapminder, - title="Graph_3", - x="continent", - y="lifeExp", - color="continent", + vm.Graph( + id="graph_2rn", + figure=px.scatter( + df_gapminder, + title="Graph_2", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ), - vm.Graph( - id="graph_2rnn", - figure=px.scatter( - df_gapminder, - title="Graph_2", - x="gdpPercap", - y="lifeExp", - size="pop", - color="continent", + vm.Graph( + id="graph_3rn", + figure=px.box( + df_gapminder, + title="Graph_3", + x="continent", + y="lifeExp", + color="continent", + ), ), - ), - ], - ), - vm.Container( - components=[ - vm.Graph( - id="graph_6rn", - figure=px.line( - df_gapminder, - title="Graph_6", - x="year", - y="lifeExp", - color="continent", - line_group="country", - hover_name="country", + vm.Graph( + id="graph_2rnn", + figure=px.scatter( + df_gapminder, + title="Graph_2", + x="gdpPercap", + y="lifeExp", + size="pop", + color="continent", + ), ), - ) - ] - ), - ], - ), - vm.Container( - title="Second container", - components=[ - vm.Graph( - id="graph_6rnn", - figure=px.line( - df_gapminder, - title="Graph_6", - x="year", - y="lifeExp", - color="continent", - line_group="country", - hover_name="country", + ], ), - ) - ], - ), - ], -) + vm.Container( + components=[ + vm.Graph( + id="graph_6rn", + figure=px.line( + df_gapminder, + title="Graph_6", + x="year", + y="lifeExp", + color="continent", + line_group="country", + hover_name="country", + ), + ) + ] + ), + ], + ), + vm.Container( + title="Second container", + components=[ + vm.Graph( + id="graph_6rnn", + figure=px.line( + df_gapminder, + title="Graph_6", + x="year", + y="lifeExp", + color="continent", + line_group="country", + hover_name="country", + ), + ) + ], + ), + ], + ) + return multiple_containers_nested + dashboard = vm.Dashboard( pages=[ - single_tabs, - single_tabs_custom_layout, - single_tabs_action, - multiple_tabs, - multiple_containers_custom_layout, - multiple_containers_nested, + single_tabs(), + single_tabs_custom_layout(), + single_tabs_action(), + multiple_tabs(), + multiple_containers_custom_layout(), + multiple_containers_nested(), ] ) diff --git a/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py b/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py index 379a066ec..dceba7d76 100644 --- a/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py +++ b/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py @@ -2,37 +2,17 @@ from __future__ import annotations -from itertools import chain from typing import TYPE_CHECKING, List from dash import page_registry from vizro.managers import model_manager -from vizro.models import VizroBaseModel if TYPE_CHECKING: from vizro.models import Action, Page from vizro.models._action._actions_chain import ActionsChain -def _get_actions(model: VizroBaseModel) -> List[ActionsChain]: - """Gets the list of the ActionsChain models for any VizroBaseModel model.""" - if hasattr(model, "selector"): - return model.selector.actions - elif hasattr(model, "actions"): - return model.actions - return [] - - -def _get_all_actions_chains_on_page(page: Page) -> List[ActionsChain]: - """Gets the list of the ActionsChain models for the Page model.""" - return [ - actions_chain - for page_item in chain([page], page.components, page.controls) - for actions_chain in _get_actions(model=page_item) - ] - - def _get_actions_chains_on_registered_pages() -> List[ActionsChain]: """Gets list of ActionsChain models for registered pages.""" actions_chains: List[ActionsChain] = [] @@ -41,7 +21,7 @@ def _get_actions_chains_on_registered_pages() -> List[ActionsChain]: page: Page = model_manager[registered_page["module"]] except KeyError: continue - actions_chains.extend(_get_all_actions_chains_on_page(page=page)) + actions_chains.extend(page._get_page_actions_chains()) return actions_chains diff --git a/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py b/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py index a8fae095a..b7a58d51a 100644 --- a/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py +++ b/vizro-core/src/vizro/actions/_callback_mapping/_callback_mapping_utils.py @@ -1,91 +1,30 @@ """Contains utilities to create the action_callback_mapping.""" -from itertools import chain -from typing import Any, Callable, Dict, List, NamedTuple +from typing import Any, Callable, Dict, List, Union from dash import Output, State, dcc -from vizro.actions import _on_page_load, _parameter, export_data, filter_interaction -from vizro.managers import data_manager, model_manager +from vizro.actions import _parameter, export_data, filter_interaction +from vizro.managers import model_manager from vizro.managers._model_manager import ModelID -from vizro.models import Action, Page, Table, VizroBaseModel -from vizro.models._action._actions_chain import ActionsChain +from vizro.models import Action, Page, Table from vizro.models._controls import Filter, Parameter from vizro.models.types import ControlType -class ModelActionsChains(NamedTuple): - model: VizroBaseModel - actions_chains: List[ActionsChain] - - -def _get_actions(model) -> List[ActionsChain]: - """Gets the list of trigger action chains in the `action` parameter for any model.""" - if hasattr(model, "selector"): - return model.selector.actions - elif hasattr(model, "actions"): - return model.actions - return [] - - -def _get_all_actions_chains_on_page(page: Page) -> chain: # type: ignore[type-arg] - """Creates an itertools chain of all ActionsChains present on selected Page.""" - return chain(*(_get_actions(page_item) for page_item in chain([page], page.components, page.controls))) - - -def _get_model_actions_chains_mapping(page: Page) -> Dict[str, ModelActionsChains]: - """Creates a mapping of model ids and ModelActionsChains for selected Page.""" - model_actions_chains_mapping = {} - for page_item in chain([page], page.components, page.controls): - model_actions_chains_mapping[page_item.id] = ModelActionsChains( - model=page_item, actions_chains=_get_actions(page_item) - ) - return model_actions_chains_mapping - - -def _get_triggered_page(action_id: ModelID) -> Page: # type: ignore[return] - """Gets the page where the provided `action_id` has been triggered.""" - for _, page in model_manager._items_with_type(Page): - if any( - action.id == action_id - for actions_chain in _get_all_actions_chains_on_page(page) - for action in actions_chain.actions - ): - return page - - -def _get_triggered_model(action_id: ModelID) -> VizroBaseModel: # type: ignore[return] - """Gets the model where the provided `action_id` has been triggered.""" - for _, page in model_manager._items_with_type(Page): - for model_id, model_actions_chains in _get_model_actions_chains_mapping(page).items(): - if any( - action.id == action_id - for actions_chain in model_actions_chains.actions_chains - for action in actions_chain.actions - ): - return model_actions_chains.model - - -def _get_components_with_data(action_id: ModelID) -> List[str]: - """Gets all components that have a registered dataframe on the page where `action_id` was triggered.""" - page = _get_triggered_page(action_id=action_id) - return [component.id for component in page.components if data_manager._has_registered_data(component.id)] - - def _get_matching_actions_by_function(page: Page, action_function: Callable[[Any], Dict[str, Any]]) -> List[Action]: """Gets list of Actions on triggered page that match the provided action function.""" return [ action - for actions_chain in _get_all_actions_chains_on_page(page) + for actions_chain in page._get_page_actions_chains() for action in actions_chain.actions if action.function._function == action_function ] # CALLBACK STATES -------------- -def _get_inputs_of_controls(action_id: ModelID, control_type: ControlType) -> List[State]: +def _get_inputs_of_controls(page: Page, control_type: ControlType) -> List[State]: """Gets list of States for selected control_type of triggered page.""" - page = _get_triggered_page(action_id=action_id) return [ State( component_id=control.selector.id, @@ -97,17 +36,16 @@ def _get_inputs_of_controls(action_id: ModelID, control_type: ControlType) -> Li def _get_inputs_of_figure_interactions( - action_id: ModelID, action_function: Callable[[Any], Dict[str, Any]] + page: Page, action_function: Callable[[Any], Dict[str, Any]] ) -> List[Dict[str, State]]: """Gets list of States for selected chart interaction `action_name` of triggered page.""" figure_interactions_on_page = _get_matching_actions_by_function( - page=_get_triggered_page(action_id=action_id), + page=page, action_function=action_function, ) inputs = [] for action in figure_interactions_on_page: - # TODO: Consider do we want to move the following logic into Model implementation - triggered_model = _get_triggered_model(action_id=ModelID(str(action.id))) + triggered_model = model_manager._get_action_trigger(action_id=ModelID(str(action.id))) if isinstance(triggered_model, Table): inputs.append( { @@ -130,9 +68,11 @@ def _get_inputs_of_figure_interactions( return inputs -def _get_action_callback_inputs(action_id: ModelID) -> Dict[str, Any]: +# TODO: Refactor this and util functions once we implement "_get_input_property" method in VizroBaseModel models +def _get_action_callback_inputs(action_id: ModelID) -> Dict[str, List[Union[State, Dict[str, State]]]]: """Creates mapping of pre-defined action names and a list of States.""" action_function = model_manager[action_id].function._function + page: Page = model_manager._get_model_page(model_id=action_id) if action_function == export_data.__wrapped__: include_inputs = ["filters", "filter_interaction"] @@ -140,17 +80,13 @@ def _get_action_callback_inputs(action_id: ModelID) -> Dict[str, Any]: include_inputs = ["filters", "parameters", "filter_interaction", "theme_selector"] action_input_mapping = { - "filters": ( - _get_inputs_of_controls(action_id=action_id, control_type=Filter) if "filters" in include_inputs else [] - ), + "filters": (_get_inputs_of_controls(page=page, control_type=Filter) if "filters" in include_inputs else []), "parameters": ( - _get_inputs_of_controls(action_id=action_id, control_type=Parameter) - if "parameters" in include_inputs - else [] + _get_inputs_of_controls(page=page, control_type=Parameter) if "parameters" in include_inputs else [] ), # TODO: Probably need to adjust other inputs to follow the same structure List[Dict[str, State]] "filter_interaction": ( - _get_inputs_of_figure_interactions(action_id=action_id, action_function=filter_interaction.__wrapped__) + _get_inputs_of_figure_interactions(page=page, action_function=filter_interaction.__wrapped__) if "filter_interaction" in include_inputs else [] ), @@ -177,9 +113,6 @@ def _get_action_callback_outputs(action_id: ModelID) -> Dict[str, Output]: if action_function == _parameter.__wrapped__: targets = [target.split(".")[0] for target in targets] - if action_function == _on_page_load.__wrapped__: - targets = _get_components_with_data(action_id=action_id) - return { target: Output( component_id=target, @@ -200,7 +133,7 @@ def _get_export_data_callback_outputs(action_id: ModelID) -> Dict[str, List[Stat targets = None if not targets: - targets = _get_components_with_data(action_id=action_id) + targets = model_manager._get_model_page(model_id=action_id)._get_page_model_ids_with_figure() return { f"download_dataframe_{target}": Output( @@ -226,7 +159,7 @@ def _get_export_data_callback_components(action_id: ModelID) -> List[dcc.Downloa targets = None if not targets: - targets = _get_components_with_data(action_id=action_id) + targets = model_manager._get_model_page(model_id=action_id)._get_page_model_ids_with_figure() return [ dcc.Download( diff --git a/vizro-core/src/vizro/actions/_on_page_load_action.py b/vizro-core/src/vizro/actions/_on_page_load_action.py index 0010a660e..63d26d9d8 100644 --- a/vizro-core/src/vizro/actions/_on_page_load_action.py +++ b/vizro-core/src/vizro/actions/_on_page_load_action.py @@ -1,35 +1,28 @@ """Pre-defined action function "_on_page_load" to be reused in `action` parameter of VizroBaseModels.""" -from typing import Any, Dict +from typing import Any, Dict, List from dash import ctx from vizro.actions._actions_utils import ( _get_modified_page_figures, ) -from vizro.managers import data_manager, model_manager from vizro.managers._model_manager import ModelID from vizro.models.types import capture @capture("action") -def _on_page_load(page_id: ModelID, **inputs: Dict[str, Any]) -> Dict[str, Any]: +def _on_page_load(targets: List[ModelID], **inputs: Dict[str, Any]) -> Dict[str, Any]: """Applies controls to charts on page once the page is opened (or refreshed). Args: - page_id: Page ID of relevant page + targets: List of target component ids to apply on page load mechanism to inputs: Dict mapping action function names with their inputs e.g. inputs = {'filters': [], 'parameters': ['gdpPercap'], 'filter_interaction': [], 'theme_selector': True} Returns: Dict mapping target chart ids to modified figures e.g. {'my_scatter': Figure({})} """ - targets = [ - component.id - for component in model_manager[page_id].components - if data_manager._has_registered_data(component.id) - ] - return _get_modified_page_figures( targets=targets, ctds_filter=ctx.args_grouping["external"]["filters"], diff --git a/vizro-core/src/vizro/managers/_data_manager.py b/vizro-core/src/vizro/managers/_data_manager.py index 355062f95..7f70fccde 100644 --- a/vizro-core/src/vizro/managers/_data_manager.py +++ b/vizro-core/src/vizro/managers/_data_manager.py @@ -74,13 +74,6 @@ def _get_component_data(self, component_id: ComponentID) -> pd.DataFrame: # to not do any inplace=True operations, but probably safest to leave it here. return self.__original_data[dataset_name].copy() - def _has_registered_data(self, component_id: ComponentID) -> bool: - try: - self._get_component_data(component_id) - return True - except KeyError: - return False - def _clear(self): self.__init__() # type: ignore[misc] diff --git a/vizro-core/src/vizro/managers/_model_manager.py b/vizro-core/src/vizro/managers/_model_manager.py index cc1b9984d..05f14bd34 100644 --- a/vizro-core/src/vizro/managers/_model_manager.py +++ b/vizro-core/src/vizro/managers/_model_manager.py @@ -3,12 +3,12 @@ import random import uuid -from typing import TYPE_CHECKING, Dict, Generator, NewType, Tuple, Type, TypeVar, cast +from typing import TYPE_CHECKING, Dict, Generator, List, NewType, Tuple, Type, TypeVar, cast from vizro.managers._managers_utils import _state_modifier if TYPE_CHECKING: - from vizro.models import VizroBaseModel + from vizro.models import Page, VizroBaseModel rd = random.Random(0) @@ -25,6 +25,8 @@ def __init__(self): self.__models: Dict[ModelID, VizroBaseModel] = {} self._frozen_state = False + # TODO: Consider do we need to save "page_id=None, parent_model_id=None" eagerly to the model itself + # and make all searching helper methods much easier? @_state_modifier def __setitem__(self, model_id: ModelID, model: Model): if model_id in self.__models: @@ -46,12 +48,63 @@ def __iter__(self) -> Generator[ModelID, None, None]: """ yield from self.__models + # TODO: Consider do we need to add additional argument "model_page_id=None" def _items_with_type(self, model_type: Type[Model]) -> Generator[Tuple[ModelID, Model], None, None]: """Iterates through all models of type `model_type` (including subclasses).""" for model_id in self: if isinstance(self[model_id], model_type): yield model_id, cast(Model, self[model_id]) + # TODO: Consider moving this one to the _callback_mapping_utils.py since it's only used there + def _get_action_trigger(self, action_id: ModelID) -> VizroBaseModel: # type: ignore[return] + from vizro.models._action._actions_chain import ActionsChain + + for _, actions_chain in model_manager._items_with_type(ActionsChain): + if action_id in [action.id for action in actions_chain.actions]: + return self[ModelID(str(actions_chain.trigger.component_id))] + + # TODO: consider returning with yield + def _get_model_children(self, model_id: ModelID) -> List[ModelID]: + """Gets all components and tabs recursively of the model with the `model_id`.""" + model_children = [] + + def __get_model_children_helper(model: VizroBaseModel) -> None: + model_children.append(ModelID(str(model.id))) + if hasattr(model, "components"): + for sub_model in model.components: + __get_model_children_helper(model=sub_model) + if hasattr(model, "tabs"): + for sub_model in model.tabs: + __get_model_children_helper(model=sub_model) + + __get_model_children_helper(model=self.__models[model_id]) + return model_children + + # TODO: Consider moving this one into Dashboard or some util file + def _get_model_page(self, model_id: ModelID) -> Page: # type: ignore[return] + """Gets the page id of the page that contains the model with the `model_id`.""" + from vizro.models import Page + + for page_id, _ in model_manager._items_with_type(Page): + page_model_ids = [page_id, self._get_model_children(model_id=page_id)] + page: Page = cast(Page, self.__models[page_id]) + + if hasattr(page, "actions"): + for actions_chain in page._get_page_actions_chains(): + page_model_ids.append(actions_chain.id) + for action in actions_chain.actions: + page_model_ids.append(action.id) + + for control in page.controls: + page_model_ids.append(control.id) + if hasattr(control, "selector") and control.selector: + page_model_ids.append(control.selector.id) + + # TODO: Add navigation, accordions and other page objects + + if model_id in page_model_ids: + return page + @staticmethod def _generate_id() -> ModelID: return ModelID(str(uuid.UUID(int=rd.getrandbits(128)))) diff --git a/vizro-core/src/vizro/models/_controls/filter.py b/vizro-core/src/vizro/models/_controls/filter.py index cc4f52e4d..6362b338b 100644 --- a/vizro-core/src/vizro/models/_controls/filter.py +++ b/vizro-core/src/vizro/models/_controls/filter.py @@ -25,7 +25,7 @@ from vizro.models.types import MultiValueType, SelectorType if TYPE_CHECKING: - from vizro.models import Page + pass from vizro.managers._model_manager import ModelID @@ -45,14 +45,6 @@ def _filter_isin(series: pd.Series, value: MultiValueType) -> pd.Series: return series.isin(value) -def _get_component_page(component_id: str) -> Page: # type: ignore[return] - from vizro.models import Page - - for page_id, page in model_manager._items_with_type(Page): - if any(control.id == component_id for control in page.controls): - return page - - class Filter(VizroBaseModel): """Filter the data supplied to `targets` on the [`Page`][vizro.models.Page]. @@ -98,11 +90,12 @@ def build(self): def _set_targets(self): if not self.targets: - for component in _get_component_page(str(self.id)).components: - if data_manager._has_registered_data(component.id): - data_frame = data_manager._get_component_data(component.id) - if self.column in data_frame.columns: - self.targets.append(component.id) + for component_id in model_manager._get_model_page( + model_id=ModelID(str(self.id)) + )._get_page_model_ids_with_figure(): + data_frame = data_manager._get_component_data(component_id) + if self.column in data_frame.columns: + self.targets.append(component_id) if not self.targets: raise ValueError(f"Selected column {self.column} not found in any dataframe on this page.") diff --git a/vizro-core/src/vizro/models/_page.py b/vizro-core/src/vizro/models/_page.py index e4e9a5f04..e187f08b4 100644 --- a/vizro-core/src/vizro/models/_page.py +++ b/vizro-core/src/vizro/models/_page.py @@ -11,7 +11,8 @@ from vizro._constants import ON_PAGE_LOAD_ACTION_PREFIX from vizro.actions import _on_page_load -from vizro.managers._model_manager import DuplicateIDError +from vizro.managers import model_manager +from vizro.managers._model_manager import DuplicateIDError, ModelID from vizro.models import Action, Layout, VizroBaseModel from vizro.models._action._actions_chain import ActionsChain, Trigger from vizro.models._models_utils import _log_call, get_unique_grid_component_ids @@ -104,10 +105,34 @@ def __init__(self, **data): f"as the page title. If you have multiple pages with the same title then you must assign a unique id." ) from exc + def _get_page_actions_chains(self) -> List[ActionsChain]: + """Gets all ActionsChains present on the page.""" + page_actions_chains = [] + + for model_id in model_manager._get_model_children(model_id=ModelID(str(self.id))): + model = model_manager[model_id] + if hasattr(model, "actions"): + page_actions_chains.extend(model.actions) + + for control in self.controls: + if hasattr(control, "selector") and control.selector: + page_actions_chains.extend(control.selector.actions) + + return page_actions_chains + + def _get_page_model_ids_with_figure(self) -> List[ModelID]: + """Gets all components that have a registered dataframe on the page.""" + return [ + model_id + for model_id in model_manager._get_model_children(model_id=ModelID(str(self.id))) + if hasattr(model_manager[model_id], "figure") + ] + @_log_call def pre_build(self): # TODO: Remove default on page load action if possible - if any(hasattr(component, "figure") for component in self.components): + targets = list(self._get_page_model_ids_with_figure()) + if targets: self.actions = [ ActionsChain( id=f"{ON_PAGE_LOAD_ACTION_PREFIX}_{self.id}", @@ -117,7 +142,8 @@ def pre_build(self): ), actions=[ Action( - id=f"{ON_PAGE_LOAD_ACTION_PREFIX}_action_{self.id}", function=_on_page_load(page_id=self.id) + id=f"{ON_PAGE_LOAD_ACTION_PREFIX}_action_{self.id}", + function=_on_page_load(targets=targets), ) ], ) @@ -159,7 +185,11 @@ def _update_graph_theme(self): # TODO: if we do this then we should *consider* defining the callback in Graph itself rather than at Page # level. This would mean multiple callbacks on one page but if it's clientside that probably doesn't matter. - themed_components = [component for component in self.components if hasattr(component, "_update_theme")] + themed_components = [ + model_manager[model_id] + for model_id in model_manager._get_model_children(model_id=ModelID(str(self.id))) + if hasattr(model_manager[model_id], "_update_theme") + ] if themed_components: @callback( diff --git a/vizro-core/tests/unit/vizro/models/test_page.py b/vizro-core/tests/unit/vizro/models/test_page.py index 3a1e34456..5c73910f2 100644 --- a/vizro-core/tests/unit/vizro/models/test_page.py +++ b/vizro-core/tests/unit/vizro/models/test_page.py @@ -110,7 +110,7 @@ def test_valid_component_types(self, standard_px_chart, standard_dash_table): [vm.Checklist(), vm.Dropdown(), vm.RadioItems(), vm.RangeSlider(), vm.Slider()], ) def test_invalid_component_types(self, test_component): - with pytest.raises(ValidationError, match=re.escape("(allowed values: 'button', 'card', 'graph', 'table')")): + with pytest.raises(ValidationError, match=re.escape("(allowed values: 'button', 'card', 'graph', 'table', 'tabs', 'container')")): vm.Page(title="Page Title", components=[test_component]) def test_valid_control_types(self, standard_px_chart):