Skip to content

Commit

Permalink
Add unit tests for Action model (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
petar-qb authored Oct 9, 2023
1 parent 0cbe19f commit 572ace5
Show file tree
Hide file tree
Showing 12 changed files with 429 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Removed
- A bullet item for the Removed category.
-->
<!--
### Added
- A bullet item for the Added category.
-->
<!--
### Changed
- A bullet item for the Changed category.
-->
<!--
### Deprecated
- A bullet item for the Deprecated category.
-->
<!--
### Fixed
- A bullet item for the Fixed category.
-->
<!--
### Security
- A bullet item for the Security category.
-->
2 changes: 1 addition & 1 deletion vizro-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ exclude_lines = [
"if __name__ == .__main__.:",
"if TYPE_CHECKING:"
]
fail_under = 82
fail_under = 86
show_missing = true
skip_covered = true

Expand Down
9 changes: 4 additions & 5 deletions vizro-core/src/vizro/actions/_action_loop/_action_loop.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""The action loop creates all the required action callbacks and its components."""

from typing import List, Union

from dash import dcc, html
from dash import html

from vizro.actions._action_loop._action_loop_utils import _get_actions_on_registered_pages
from vizro.actions._action_loop._build_action_loop_callbacks import _build_action_loop_callbacks
Expand All @@ -11,13 +10,13 @@

class ActionLoop:
@classmethod
def _create_app_callbacks(cls) -> List[Union[dcc.Store, html.Div, dcc.Download]]:
def _create_app_callbacks(cls) -> html.Div:
"""Builds callbacks for the action loop and for each Action in the Dashboard and returns their components.
Returns:
List of required components for the action loop and for each `Action` in the `Dashboard`.
"""
return cls._build_action_loop() + cls._build_actions_models()
return html.Div([cls._build_action_loop(), cls._build_actions_models()], id="app_components_div")

@staticmethod
def _build_action_loop():
Expand All @@ -37,4 +36,4 @@ def _build_actions_models():
List of required components for each `Action` in the `Dashboard` e.g. List[dcc.Download]
"""
actions = _get_actions_on_registered_pages()
return [action_component for action in actions for action_component in action.build()]
return html.Div([action.build() for action in actions], id="app_action_models_components_div")
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Contains utilities to create required components for the action loop."""
from typing import List, Union

from dash import dcc, html

Expand All @@ -9,7 +8,7 @@
)


def _get_action_loop_components() -> List[Union[dcc.Store, html.Div]]:
def _get_action_loop_components() -> html.Div:
"""Gets all required components for the action loop.
Returns:
Expand All @@ -19,7 +18,7 @@ def _get_action_loop_components() -> List[Union[dcc.Store, html.Div]]:
actions = _get_actions_on_registered_pages()

if not actions_chains:
return []
return html.Div(id="action_loop_components_div")

# Fundamental components required for the smooth operation of the action loop mechanism.
components = [
Expand Down Expand Up @@ -65,4 +64,4 @@ def _get_action_loop_components() -> List[Union[dcc.Store, html.Div]]:
)
)

return components
return html.Div(components, id="action_loop_components_div")
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _get_inputs_of_chart_interactions(
)
return [
State(
component_id=_get_triggered_model(action_id=action.id).id, # type: ignore[arg-type]
component_id=_get_triggered_model(action_id=ModelID(str(action.id))).id,
component_property="clickData",
)
for action in chart_interactions_on_page
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Creates action_callback_mapping to map callback arguments to action functions."""

from typing import Any, Dict, Union
from typing import Any, Dict, List, Union

from dash import Input, Output, State
from dash import dcc
from dash.dependencies import DashDependency

from vizro.actions import export_data, filter_interaction
from vizro.actions._callback_mapping._callback_mapping_utils import (
Expand All @@ -18,7 +19,9 @@
from vizro.managers._model_manager import ModelID


def _get_action_callback_mapping(action_id: ModelID, argument: str) -> Union[Dict[str, Union[Input, State, Output]]]:
def _get_action_callback_mapping(
action_id: ModelID, argument: str
) -> Union[List[dcc.Download], Dict[str, DashDependency]]:
"""Creates mapping of action name and required callback input/output."""
action_function = model_manager[action_id].function._function # type: ignore[attr-defined]

Expand All @@ -42,5 +45,6 @@ def _get_action_callback_mapping(action_id: ModelID, argument: str) -> Union[Dic
},
_on_page_load.__wrapped__: {"inputs": _get_action_callback_inputs, "outputs": _get_action_callback_outputs},
}
action_call = action_callback_mapping.get(action_function, {}).get(argument, {})
return action_call if isinstance(action_call, dict) else action_call(action_id=action_id)
action_call = action_callback_mapping.get(action_function, {}).get(argument)
default_value: Union[List[dcc.Download], Dict[str, DashDependency]] = [] if argument == "components" else {}
return default_value if not action_call else action_call(action_id=action_id)
98 changes: 77 additions & 21 deletions vizro-core/src/vizro/models/_action/_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import logging
from typing import Any, Dict, List

from dash import Input, Output, State, callback, ctx
from dash import Input, Output, State, callback, ctx, html
from pydantic import Field, validator

import vizro.actions
from vizro.managers._model_manager import ModelID
from vizro.models import VizroBaseModel
from vizro.models._models_utils import _log_call
from vizro.models.types import CapturedCallable
Expand Down Expand Up @@ -54,12 +55,35 @@ def validate_predefined_actions(cls, function):
)
return function

@_log_call
def build(self):
@staticmethod
def _validate_output_number(outputs, return_value):
return_value_len = (
1 if not hasattr(return_value, "__len__") or isinstance(return_value, str) else len(return_value)
)

# Raising the custom exception if the callback return value length doesn't match the number of defined outputs.
if len(outputs) != return_value_len:
raise ValueError(
f"Number of action's returned elements ({return_value_len}) does not match the number"
f" of action's defined outputs ({len(outputs)})."
)

def _get_callback_mapping(self):
"""Builds callback inputs and outputs for the Action model callback, and returns action required components.
callback_inputs, and callback_outputs are "dash.State" and "dash.Output" objects made of three parts:
1. User configured inputs/outputs - for custom actions,
2. Vizro configured inputs/outputs - for predefined actions,
3. Hardcoded inputs/outputs - for custom and predefined actions
(enable callbacks to live inside the Action loop).
Returns: List of required components (e.g. dcc.Download) for the Action model added to the `Dashboard`
container. Those components represent the return value of the Action build method.
"""
from vizro.actions._callback_mapping._get_action_callback_mapping import _get_action_callback_mapping

callback_inputs: Dict[str, Any] = {
**_get_action_callback_mapping(action_id=self.id, argument="inputs"), # type: ignore[arg-type]
**_get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="inputs"),
**{
f'{input.split(".")[0]}_{input.split(".")[1]}': State(input.split(".")[0], input.split(".")[1])
for input in self.inputs
Expand All @@ -68,7 +92,7 @@ def build(self):
}

callback_outputs: Dict[str, Any] = {
**_get_action_callback_mapping(action_id=self.id, argument="outputs"), # type: ignore[arg-type]
**_get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="outputs"),
**{
f'{output.split(".")[0]}_{output.split(".")[1]}': Output(
output.split(".")[0], output.split(".")[1], allow_duplicate=True
Expand All @@ -78,6 +102,46 @@ def build(self):
"action_finished": Output("action_finished", "data", allow_duplicate=True),
}

action_components = _get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="components")

return callback_inputs, callback_outputs, action_components

def _action_callback_function(self, **inputs: Dict[str, Any]) -> Dict[str, Any]:
logger.debug("=============== ACTION ===============")
logger.debug(f'Action ID: "{self.id}"')
logger.debug(f'Action name: "{self.function._function.__name__}"')
logger.debug(f"Action inputs: {inputs}")

# Invoking the action's function
return_value = self.function(**inputs) or {}

# Action callback outputs
outputs = list(ctx.outputs_grouping.keys())
outputs.remove("action_finished")

# Validate number of outputs
self._validate_output_number(
outputs=outputs,
return_value=return_value,
)

# If return_value is a single element, ensure return_value is a list
if not isinstance(return_value, (list, tuple, dict)):
return_value = [return_value]
if isinstance(return_value, dict):
return {"action_finished": None, **return_value}

return {"action_finished": None, **dict(zip(outputs, return_value))}

@_log_call
def build(self):
"""Builds a callback for the Action model and returns required components for the callback.
Returns:
List of required components (e.g. dcc.Download) for the Action model added to the `Dashboard` container.
"""
callback_inputs, callback_outputs, action_components = self._get_callback_mapping()

logger.debug(
f"Creating Callback mapping for Action ID {self.id} with "
f"function name: {self.function._function.__name__}"
Expand All @@ -91,19 +155,11 @@ def build(self):
logger.debug("============================")

@callback(output=callback_outputs, inputs=callback_inputs, prevent_initial_call=True)
def callback_wrapper(trigger: None, **inputs: Dict[str, Any]):
logger.debug("=============== ACTION ===============")
logger.debug(f'Action ID: "{self.id}"')
logger.debug(f'Action name: "{self.function._function.__name__}"')
logger.debug(f"Action inputs: {inputs}")
return_value = self.function(**inputs) or {}
if isinstance(return_value, dict):
return {"action_finished": None, **return_value}

if not isinstance(return_value, list) and not isinstance(return_value, tuple):
return_value = [return_value]

# Map returned values to dictionary format where None belongs to the "action_finished" output
return dict(zip(ctx.outputs_grouping.keys(), [None, *return_value]))

return _get_action_callback_mapping(action_id=self.id, argument="components") # type: ignore[arg-type]
def callback_wrapper(trigger: None, **inputs: Dict[str, Any]) -> Dict[str, Any]:
return self._action_callback_function(**inputs)

# return action_components
return html.Div(
children=action_components,
id=f"{self.id}_action_model_components_div",
)
2 changes: 1 addition & 1 deletion vizro-core/src/vizro/models/_components/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def process_figure_data_frame(cls, figure, values):

# Convenience wrapper/syntactic sugar.
def __call__(self, **kwargs):
kwargs.setdefault("data_frame", data_manager._get_component_data(self.id)) # type: ignore[arg-type]
kwargs.setdefault("data_frame", data_manager._get_component_data(str(self.id)))
fig = self.figure(**kwargs)

# Remove top margin if title is provided
Expand Down
2 changes: 1 addition & 1 deletion vizro-core/src/vizro/models/_controls/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def build(self):

def _set_targets(self):
if not self.targets:
for component in _get_component_page(self.id).components: # type: ignore[arg-type]
for component in _get_component_page(str(self.id)).components:
if data_manager._has_registered_data(component.id):
data_frame = data_manager._get_component_data(component.id)
if self.column in data_frame.columns:
Expand Down
2 changes: 1 addition & 1 deletion vizro-core/src/vizro/models/_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def build(self):
id="dashboard_container",
children=[
html.Div(id=f"vizro_version_{vizro.__version__}"),
*ActionLoop._create_app_callbacks(),
ActionLoop._create_app_callbacks(),
dash.page_container,
],
className=self.theme,
Expand Down
Loading

0 comments on commit 572ace5

Please sign in to comment.