diff --git a/vizro-core/changelog.d/20241105_170003_antony.milne_new_interaction.md b/vizro-core/changelog.d/20241105_170003_antony.milne_new_interaction.md new file mode 100644 index 000000000..7c0d58d4f --- /dev/null +++ b/vizro-core/changelog.d/20241105_170003_antony.milne_new_interaction.md @@ -0,0 +1,48 @@ + + + + + + + + + diff --git a/vizro-core/changelog.d/20241106_104745_antony.milne_dynamic_filter_2.md b/vizro-core/changelog.d/20241106_104745_antony.milne_dynamic_filter_2.md new file mode 100644 index 000000000..6108280fd --- /dev/null +++ b/vizro-core/changelog.d/20241106_104745_antony.milne_dynamic_filter_2.md @@ -0,0 +1,47 @@ + + + + + + +### Changed + +- Improve performance of data loading. ([#850](https://github.com/mckinsey/vizro/pull/850), [#857](https://github.com/mckinsey/vizro/pull/857)) + + + + diff --git a/vizro-core/src/vizro/actions/_actions_utils.py b/vizro-core/src/vizro/actions/_actions_utils.py index a4d6a78c0..484d41ed4 100644 --- a/vizro-core/src/vizro/actions/_actions_utils.py +++ b/vizro-core/src/vizro/actions/_actions_utils.py @@ -2,7 +2,6 @@ from __future__ import annotations -from collections import defaultdict from copy import deepcopy from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union @@ -10,6 +9,7 @@ from vizro._constants import ALL_OPTION, NONE_OPTION from vizro.managers import data_manager, model_manager +from vizro.managers._data_manager import DataSourceName from vizro.managers._model_manager import ModelID from vizro.models.types import MultiValueType, SelectorType, SingleValueType @@ -23,7 +23,7 @@ class CallbackTriggerDict(TypedDict): """Represent dash.ctx.args_grouping item. Shortened as 'ctd' in the code. Args: - id: The component ID. If it`s a pattern matching ID, it will be a dict. + id: The component ID. If it's a pattern matching ID, it will be a dict. property: The component property used in the callback. value: The value of the component property at the time the callback was fired. str_id: For pattern matching IDs, it's the stringified dict ID without white spaces. @@ -47,7 +47,18 @@ def _get_component_actions(component) -> list[Action]: ) -def _apply_filters(data_frame: pd.DataFrame, ctds_filters: list[CallbackTriggerDict], target: str) -> pd.DataFrame: +def _apply_filter_controls( + data_frame: pd.DataFrame, ctds_filters: list[CallbackTriggerDict], target: ModelID +) -> pd.DataFrame: + """Applies filters from a vm.Filter model in the controls. + + Args: + data_frame: unfiltered DataFrame. + ctds_filters: list of CallbackTriggerDict for filters. + target: id of targeted Figure. + + Returns: filtered DataFrame. + """ for ctd in ctds_filters: selector_value = ctd["value"] selector_value = selector_value if isinstance(selector_value, list) else [selector_value] @@ -84,8 +95,19 @@ def _get_parent_vizro_model(_underlying_callable_object_id: str) -> VizroBaseMod def _apply_filter_interaction( - data_frame: pd.DataFrame, ctds_filter_interaction: list[dict[str, CallbackTriggerDict]], target: str + data_frame: pd.DataFrame, ctds_filter_interaction: list[dict[str, CallbackTriggerDict]], target: ModelID ) -> pd.DataFrame: + """Applies filters from a filter_interaction. + + This will be removed in future when filter interactions are implemented using controls. + + Args: + data_frame: unfiltered DataFrame. + ctds_filter_interaction: structure containing CallbackTriggerDict for filter interactions. + target: id of targeted Figure. + + Returns: filtered DataFrame. + """ for ctd_filter_interaction in ctds_filter_interaction: triggered_model = model_manager[ctd_filter_interaction["modelID"]["id"]] data_frame = triggered_model._filter_interaction( @@ -105,120 +127,150 @@ def _validate_selector_value_none(value: Union[SingleValueType, MultiValueType]) return value -def _create_target_arg_mapping(dot_separated_strings: list[str]) -> dict[str, list[str]]: - results = defaultdict(list) - for string in dot_separated_strings: - if "." not in string: - raise ValueError(f"Provided string {string} must contain a '.'") - component, arg = string.split(".", 1) - results[component].append(arg) - return results +def _get_target_dot_separated_strings(dot_separated_strings: list[str], target: ModelID, data_frame: bool) -> list[str]: + """Filters list of dot separated strings to get just those relevant for a single target. + Args: + dot_separated_strings: list of dot separated strings that can be targeted by a vm.Parameter, + e.g. ["target_name.data_frame.arg", "target_name.x"] + target: id of targeted Figure. + data_frame: whether to return only DataFrame parameters starting "data_frame." or only non-DataFrame parameters. -def _update_nested_graph_properties( - graph_config: dict[str, Any], dot_separated_string: str, value: Any + Returns: + List of dot separated strings for target. + """ + result = [] + + for dot_separated_string_with_target in dot_separated_strings: + if dot_separated_string_with_target.startswith(f"{target}."): + dot_separated_string = dot_separated_string_with_target.removeprefix(f"{target}.") + # We only want data_frame parameters when data_frame = True. + if dot_separated_string.startswith("data_frame.") == data_frame: + result.append(dot_separated_string) + return result + + +def _update_nested_figure_properties( + figure_config: dict[str, Any], dot_separated_string: str, value: Any ) -> dict[str, Any]: keys = dot_separated_string.split(".") - current_property = graph_config + current_property = figure_config for key in keys[:-1]: current_property = current_property.setdefault(key, {}) current_property[keys[-1]] = value - return graph_config + return figure_config + +def _get_parametrized_config( + ctd_parameters: list[CallbackTriggerDict], target: ModelID, data_frame: bool +) -> dict[str, Any]: + """Convert parameters into a keyword-argument dictionary. -def _get_parametrized_config(target: ModelID, ctd_parameters: list[CallbackTriggerDict]) -> dict[str, Any]: - # TODO - avoid calling _captured_callable. Once we have done this we can remove _arguments from - # CapturedCallable entirely. - config = deepcopy(model_manager[target].figure._arguments) + Args: + ctd_parameters: list of CallbackTriggerDicts for vm.Parameter. + target: id of targeted figure. + data_frame: whether to return only DataFrame parameters starting "data_frame." or only non-DataFrame parameters. - # It's not possible to address nested argument of data_frame like data_frame.x.y, just top-level ones like - # data_frame.x. - config["data_frame"] = {} + Returns: keyword-argument dictionary. + + """ + if data_frame: + # This entry is inserted (but will always be empty) even for static data so that the load/_multi_load calls + # look identical for dynamic data with no arguments and static data. Note it's not possible to address nested + # argument of data_frame like data_frame.x.y, just top-level ones like data_frame.x. + config: dict[str, Any] = {"data_frame": {}} + else: + # TODO - avoid calling _captured_callable. Once we have done this we can remove _arguments from + # CapturedCallable entirely. This might mean not being able to address nested parameters. + config = deepcopy(model_manager[target].figure._arguments) + del config["data_frame"] for ctd in ctd_parameters: # TODO: needs to be refactored so that it is independent of implementation details - selector_value = ctd["value"] + parameter_value = ctd["value"] - if hasattr(selector_value, "__iter__") and ALL_OPTION in selector_value: # type: ignore[operator] - selector: SelectorType = model_manager[ctd["id"]] - - # Even if options are provided as list[dict], the Dash component only returns a list of values. + selector: SelectorType = model_manager[ctd["id"]] + if hasattr(parameter_value, "__iter__") and ALL_OPTION in parameter_value: # type: ignore[operator] + # Even if an option is provided as list[dict], the Dash component only returns a list of values. # So we need to ensure that we always return a list only as well to provide consistent types. - if all(isinstance(option, dict) for option in selector.options): - selector_value = [option["value"] for option in selector.options] - else: - selector_value = selector.options + parameter_value = [option["value"] if isinstance(option, dict) else option for option in selector.options] - selector_value = _validate_selector_value_none(selector_value) - selector_actions = _get_component_actions(model_manager[ctd["id"]]) + parameter_value = _validate_selector_value_none(parameter_value) - for action in selector_actions: + for action in _get_component_actions(selector): if action.function._function.__name__ != "_parameter": continue - action_targets = _create_target_arg_mapping(action.function["targets"]) - - if target not in action_targets: - continue - - for action_targets_arg in action_targets[target]: - config = _update_nested_graph_properties( - graph_config=config, dot_separated_string=action_targets_arg, value=selector_value + for dot_separated_string in _get_target_dot_separated_strings( + action.function["targets"], target, data_frame + ): + config = _update_nested_figure_properties( + figure_config=config, dot_separated_string=dot_separated_string, value=parameter_value ) return config # Helper functions used in pre-defined actions ---- -def _get_targets_data_and_config( +def _apply_filters( + data: pd.DataFrame, ctds_filter: list[CallbackTriggerDict], ctds_filter_interaction: list[dict[str, CallbackTriggerDict]], - ctds_parameters: list[CallbackTriggerDict], - targets: list[ModelID], + target: ModelID, ): - all_filtered_data = {} - all_parameterized_config = {} - + # Takes in just one target, so dataframe is filtered repeatedly for every target that uses it. + # Potentially this could be de-duplicated but it's not so important since filtering is a relatively fast + # operation (compared to data loading). + filtered_data = _apply_filter_controls(data_frame=data, ctds_filters=ctds_filter, target=target) + filtered_data = _apply_filter_interaction( + data_frame=filtered_data, ctds_filter_interaction=ctds_filter_interaction, target=target + ) + return filtered_data + + +def _get_unfiltered_data( + ctds_parameters: list[CallbackTriggerDict], targets: list[ModelID] +) -> dict[ModelID, pd.DataFrame]: + # Takes in multiple targets to ensure that data can be loaded efficiently using _multi_load and not repeated for + # every single target. + # Getting unfiltered data requires data frame parameters. We pass in all ctd_parameters and then find the + # data_frame ones by passing data_frame=True in the call to _get_paramaterized_config. Static data is also + # handled here and will just have empty dictionary for its kwargs. + multi_data_source_name_load_kwargs: list[tuple[DataSourceName, dict[str, Any]]] = [] for target in targets: - # parametrized_config includes a key "data_frame" that is used in the data loading function. - parameterized_config = _get_parametrized_config(target=target, ctd_parameters=ctds_parameters) - data_source_name = model_manager[target]["data_frame"] - data_frame = data_manager[data_source_name].load(**parameterized_config["data_frame"]) - - filtered_data = _apply_filters(data_frame=data_frame, ctds_filters=ctds_filter, target=target) - filtered_data = _apply_filter_interaction( - data_frame=filtered_data, ctds_filter_interaction=ctds_filter_interaction, target=target + dynamic_data_load_params = _get_parametrized_config( + ctd_parameters=ctds_parameters, target=target, data_frame=True ) + data_source_name = model_manager[target]["data_frame"] + multi_data_source_name_load_kwargs.append((data_source_name, dynamic_data_load_params["data_frame"])) - # Parameters affecting data_frame have already been used above in data loading and so are excluded from - # all_parameterized_config. - all_filtered_data[target] = filtered_data - all_parameterized_config[target] = { - key: value for key, value in parameterized_config.items() if key != "data_frame" - } - - return all_filtered_data, all_parameterized_config + return dict(zip(targets, data_manager._multi_load(multi_data_source_name_load_kwargs))) def _get_modified_page_figures( ctds_filter: list[CallbackTriggerDict], ctds_filter_interaction: list[dict[str, CallbackTriggerDict]], ctds_parameters: list[CallbackTriggerDict], - targets: Optional[list[ModelID]] = None, -) -> dict[str, Any]: - targets = targets or [] - - filtered_data, parameterized_config = _get_targets_data_and_config( - ctds_filter=ctds_filter, - ctds_filter_interaction=ctds_filter_interaction, - ctds_parameters=ctds_parameters, - targets=targets, - ) + targets: list[ModelID], +) -> dict[ModelID, Any]: + outputs: dict[ModelID, Any] = {} + + # TODO: the structure here would be nicer if we could get just the ctds for a single target at one time, + # so you could do apply_filters on a target a pass only the ctds relevant for that target. + # Consider restructuring ctds to a more convenient form to make this possible. + + for target, unfiltered_data in _get_unfiltered_data(ctds_parameters, targets).items(): + filtered_data = _apply_filters(unfiltered_data, ctds_filter, ctds_filter_interaction, target) + outputs[target] = model_manager[target]( + data_frame=filtered_data, + **_get_parametrized_config(ctd_parameters=ctds_parameters, target=target, data_frame=False), + ) - outputs: dict[str, Any] = {} - for target in targets: - outputs[target] = model_manager[target](data_frame=filtered_data[target], **parameterized_config[target]) + # TODO NEXT: will need to pass unfiltered_data into Filter.__call__. + # This dictionary is filtered for correct targets already selected in Filter.__call__ or that could be done here + # instead. + # {target: data_frame for target, data_frame in unfiltered_data.items() if target in self.targets} return outputs diff --git a/vizro-core/src/vizro/actions/_filter_action.py b/vizro-core/src/vizro/actions/_filter_action.py index 409428e01..f3ec21b37 100644 --- a/vizro-core/src/vizro/actions/_filter_action.py +++ b/vizro-core/src/vizro/actions/_filter_action.py @@ -16,7 +16,7 @@ def _filter( targets: list[ModelID], filter_function: Callable[[pd.Series, Any], pd.Series], **inputs: dict[str, Any], -) -> dict[str, Any]: +) -> dict[ModelID, Any]: """Filters targeted charts/components on page by interaction with `Filter` control. Args: @@ -28,11 +28,10 @@ def _filter( Returns: Dict mapping target component ids to modified charts/components e.g. {'my_scatter': Figure({})} - """ return _get_modified_page_figures( - targets=targets, ctds_filter=ctx.args_grouping["external"]["filters"], ctds_filter_interaction=ctx.args_grouping["external"]["filter_interaction"], ctds_parameters=ctx.args_grouping["external"]["parameters"], + targets=targets, ) diff --git a/vizro-core/src/vizro/actions/_on_page_load_action.py b/vizro-core/src/vizro/actions/_on_page_load_action.py index 5b2d97cdb..306ed9b5e 100644 --- a/vizro-core/src/vizro/actions/_on_page_load_action.py +++ b/vizro-core/src/vizro/actions/_on_page_load_action.py @@ -10,7 +10,7 @@ @capture("action") -def _on_page_load(targets: list[ModelID], **inputs: dict[str, Any]) -> dict[str, Any]: +def _on_page_load(targets: list[ModelID], **inputs: dict[str, Any]) -> dict[ModelID, Any]: """Applies controls to charts on page once the page is opened (or refreshed). Args: @@ -23,8 +23,8 @@ def _on_page_load(targets: list[ModelID], **inputs: dict[str, Any]) -> dict[str, """ return _get_modified_page_figures( - targets=targets, ctds_filter=ctx.args_grouping["external"]["filters"], ctds_filter_interaction=ctx.args_grouping["external"]["filter_interaction"], ctds_parameters=ctx.args_grouping["external"]["parameters"], + targets=targets, ) diff --git a/vizro-core/src/vizro/actions/_parameter_action.py b/vizro-core/src/vizro/actions/_parameter_action.py index e96b136fb..6284481ec 100644 --- a/vizro-core/src/vizro/actions/_parameter_action.py +++ b/vizro-core/src/vizro/actions/_parameter_action.py @@ -10,7 +10,7 @@ @capture("action") -def _parameter(targets: list[str], **inputs: dict[str, Any]) -> dict[str, Any]: +def _parameter(targets: list[str], **inputs: dict[str, Any]) -> dict[ModelID, Any]: """Modifies parameters of targeted charts/components on page. Args: @@ -25,8 +25,8 @@ def _parameter(targets: list[str], **inputs: dict[str, Any]) -> dict[str, Any]: target_ids: list[ModelID] = [target.split(".")[0] for target in targets] # type: ignore[misc] return _get_modified_page_figures( - targets=target_ids, ctds_filter=ctx.args_grouping["external"]["filters"], ctds_filter_interaction=ctx.args_grouping["external"]["filter_interaction"], ctds_parameters=ctx.args_grouping["external"]["parameters"], + targets=target_ids, ) diff --git a/vizro-core/src/vizro/actions/export_data_action.py b/vizro-core/src/vizro/actions/export_data_action.py index fb87f419e..923639998 100644 --- a/vizro-core/src/vizro/actions/export_data_action.py +++ b/vizro-core/src/vizro/actions/export_data_action.py @@ -5,7 +5,7 @@ from dash import ctx, dcc from typing_extensions import Literal -from vizro.actions._actions_utils import _get_targets_data_and_config +from vizro.actions._actions_utils import _apply_filters, _get_unfiltered_data from vizro.managers import model_manager from vizro.managers._model_manager import ModelID from vizro.models.types import capture @@ -41,23 +41,19 @@ def export_data( if target not in model_manager: raise ValueError(f"Component '{target}' does not exist.") - data_frames, _ = _get_targets_data_and_config( - targets=targets, - ctds_filter=ctx.args_grouping["external"]["filters"], - ctds_filter_interaction=ctx.args_grouping["external"]["filter_interaction"], - ctds_parameters=ctx.args_grouping["external"]["parameters"], - ) - + ctds = ctx.args_grouping["external"] outputs = {} - for target_id in targets: + + for target, unfiltered_data in _get_unfiltered_data(ctds["parameters"], targets).items(): + filtered_data = _apply_filters(unfiltered_data, ctds["filters"], ctds["filter_interaction"], target) if file_format == "csv": - writer = data_frames[target_id].to_csv + writer = filtered_data.to_csv elif file_format == "xlsx": - writer = data_frames[target_id].to_excel + writer = filtered_data.to_excel # Invalid file_format should be caught by Action validation - outputs[f"download_dataframe_{target_id}"] = dcc.send_data_frame( - writer=writer, filename=f"{target_id}.{file_format}", index=False + outputs[f"download_dataframe_{target}"] = dcc.send_data_frame( + writer=writer, filename=f"{target}.{file_format}", index=False ) return outputs diff --git a/vizro-core/src/vizro/actions/filter_interaction_action.py b/vizro-core/src/vizro/actions/filter_interaction_action.py index 1f40f3171..bc6659ab9 100644 --- a/vizro-core/src/vizro/actions/filter_interaction_action.py +++ b/vizro-core/src/vizro/actions/filter_interaction_action.py @@ -10,7 +10,7 @@ @capture("action") -def filter_interaction(targets: Optional[list[ModelID]] = None, **inputs: dict[str, Any]) -> dict[str, Any]: +def filter_interaction(targets: Optional[list[ModelID]] = None, **inputs: dict[str, Any]) -> dict[ModelID, Any]: """Filters targeted charts/components on page by clicking on data points or table cells of the source chart. To set up filtering on specific columns of the target graph(s), include these columns in the 'custom_data' @@ -29,8 +29,8 @@ def filter_interaction(targets: Optional[list[ModelID]] = None, **inputs: dict[s """ return _get_modified_page_figures( - targets=targets, ctds_filter=ctx.args_grouping["external"]["filters"], ctds_filter_interaction=ctx.args_grouping["external"]["filter_interaction"], ctds_parameters=ctx.args_grouping["external"]["parameters"], + targets=targets or [], ) diff --git a/vizro-core/src/vizro/managers/_data_manager.py b/vizro-core/src/vizro/managers/_data_manager.py index 44e7972db..9b178e084 100644 --- a/vizro-core/src/vizro/managers/_data_manager.py +++ b/vizro-core/src/vizro/managers/_data_manager.py @@ -3,11 +3,12 @@ from __future__ import annotations import functools +import json import logging import os import warnings from functools import partial -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union import pandas as pd import wrapt @@ -196,6 +197,41 @@ def __getitem__(self, name: DataSourceName) -> Union[_DynamicData, _StaticData]: except KeyError as exc: raise KeyError(f"Data source {name} does not exist.") from exc + def _multi_load(self, multi_name_load_kwargs: list[tuple[DataSourceName, dict[str, Any]]]) -> list[pd.DataFrame]: + """Loads multiple data sources as efficiently as possible. + + Deduplicates a list of (data source name, load keyword argument dictionary) tuples so that each one corresponds + to only a single load() call. In the worst case scenario where there are no repeated tuples then performance of + this function is identical to doing a load call for each tuple. + + If a data source is static then load keyword argument dictionary must be {}. + + Args: + multi_name_load_kwargs: List of (data source name, load keyword argument dictionary). + + Returns: + Loaded data in the same order as `multi_name_load_kwargs` was supplied. + """ + + # Easiest way to make a key to de-duplicate each (data source name, load keyword argument dictionary) tuple. + def encode_load_key(name, load_kwargs): + return json.dumps([name, load_kwargs], sort_keys=True) + + def decode_load_key(key): + return json.loads(key) + + # dict.fromkeys does the de-duplication. + load_key_to_data = dict.fromkeys( + encode_load_key(name, load_kwargs) for name, load_kwargs in multi_name_load_kwargs + ) + + # Load each key only once. + for load_key in load_key_to_data.keys(): + name, load_kwargs = decode_load_key(load_key) + load_key_to_data[load_key] = self[name].load(**load_kwargs) + + return [load_key_to_data[encode_load_key(name, load_kwargs)] for name, load_kwargs in multi_name_load_kwargs] + def _clear(self): # We do not actually call self.cache.clear() because (a) it would only work when self._cache_has_app is True, # which is not the case when e.g. Vizro._reset is called, and (b) because we do not want to accidentally diff --git a/vizro-core/src/vizro/models/_controls/filter.py b/vizro-core/src/vizro/models/_controls/filter.py index e70fad98c..683e7f870 100644 --- a/vizro-core/src/vizro/models/_controls/filter.py +++ b/vizro-core/src/vizro/models/_controls/filter.py @@ -6,6 +6,8 @@ import pandas as pd from pandas.api.types import is_datetime64_any_dtype, is_numeric_dtype +from vizro.managers._data_manager import DataSourceName + try: from pydantic.v1 import Field, PrivateAttr, validator except ImportError: # pragma: no cov @@ -96,19 +98,27 @@ def check_target_present(cls, target): @_log_call def pre_build(self): - if self.targets: - targeted_data = self._validate_targeted_data(targets=self.targets) - else: - # If targets aren't explicitly provided then try to target all figures on the page. In this case we don't - # want to raise an error if the column is not found in a figure's data_frame, it will just be ignored. - # Possibly in future this will change (which would be breaking change). - targeted_data = self._validate_targeted_data( - targets=model_manager._get_page_model_ids_with_figure( - page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id))) - ), - eagerly_raise_column_not_found_error=False, - ) - self.targets = list(targeted_data.columns) + # If targets aren't explicitly provided then try to target all figures on the page. In this case we don't + # want to raise an error if the column is not found in a figure's data_frame, it will just be ignored. + # This is the case when bool(self.targets) is False. + # Possibly in future this will change (which would be breaking change). + proposed_targets = self.targets or model_manager._get_page_model_ids_with_figure( + page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id))) + ) + # TODO NEXT: how to handle pre_build for dynamic filters? Do we still require default argument values in + # `load` to establish selector type etc.? Can we take selector values from model_manager to supply these? + # Or just don't do validation at pre_build time and wait until state is available during build time instead? + # What should the load kwargs be here? Remember they need to be {} for static data. + # Note that currently _get_unfiltered_data is only suitable for use at runtime since it requires + # ctd_parameters. That could be changed to just reuse that function. + multi_data_source_name_load_kwargs: list[tuple[DataSourceName, dict[str, Any]]] = [ + (model_manager[target]["data_frame"], {}) for target in proposed_targets + ] + target_to_data_frame = dict(zip(proposed_targets, data_manager._multi_load(multi_data_source_name_load_kwargs))) + targeted_data = self._validate_targeted_data( + target_to_data_frame, eagerly_raise_column_not_found_error=bool(self.targets) + ) + self.targets = list(targeted_data.columns) # Set default selector according to column type. self._column_type = self._validate_column_type(targeted_data) @@ -148,12 +158,15 @@ def pre_build(self): ) ] - def __call__(self, **kwargs): + def __call__(self, target_to_data_frame: dict[ModelID, pd.DataFrame]): # Only relevant for a dynamic filter. - # TODO: this will need to pass parametrised data_frame arguments through to _validate_targeted_data. # Although targets are fixed at build time, the validation logic is repeated during runtime, so if a column # is missing then it will raise an error. We could change this if we wanted. - targeted_data = self._validate_targeted_data(targets=self.targets) + # Call this from actions_utils + targeted_data = self._validate_targeted_data( + {target: data_frame for target, data_frame in target_to_data_frame.items() if target in self.targets}, + eagerly_raise_column_not_found_error=True, + ) if (column_type := self._validate_column_type(targeted_data)) != self._column_type: raise ValueError( @@ -173,29 +186,11 @@ def build(self): return self.selector.build() def _validate_targeted_data( - self, targets: list[ModelID], eagerly_raise_column_not_found_error=True + self, target_to_data_frame: dict[ModelID, pd.DataFrame], eagerly_raise_column_not_found_error ) -> pd.DataFrame: - # TODO: consider moving some of this logic to data_manager when implement dynamic filter. Make sure - # get_modified_figures and stuff in _actions_utils.py is as efficient as code here. - - # When loading data_frame there are possible keys: - # 1. target. In worst case scenario this is needed but can lead to unnecessary repeated data loading. - # 2. data_source_name. No repeated data loading but won't work when applying data_frame parameters at runtime. - # 3. target + data_frame parameters keyword-argument pairs. This is the correct key to use at runtime. - # For now we follow scheme 2 for data loading (due to set() below) and 1 for the returned targeted_data - # pd.DataFrame, i.e. a separate column for each target even if some data is repeated. - # TODO: when this works with data_frame parameters load() will need to take arguments and the structures here - # might change a bit. - target_to_data_source_name = {target: model_manager[target]["data_frame"] for target in targets} - data_source_name_to_data = { - data_source_name: data_manager[data_source_name].load() - for data_source_name in set(target_to_data_source_name.values()) - } target_to_series = {} - for target, data_source_name in target_to_data_source_name.items(): - data_frame = data_source_name_to_data[data_source_name] - + for target, data_frame in target_to_data_frame.items(): if self.column in data_frame.columns: # reset_index so that when we make a DataFrame out of all these pd.Series pandas doesn't try to align # the columns by index. @@ -206,10 +201,14 @@ def _validate_targeted_data( targeted_data = pd.DataFrame(target_to_series) if targeted_data.columns.empty: # Still raised when eagerly_raise_column_not_found_error=False. - raise ValueError(f"Selected column {self.column} not found in any dataframe for {', '.join(targets)}.") + raise ValueError( + f"Selected column {self.column} not found in any dataframe for " + f"{', '.join(target_to_data_frame.keys())}." + ) if targeted_data.empty: raise ValueError( - f"Selected column {self.column} does not contain anything in any dataframe for {', '.join(targets)}." + f"Selected column {self.column} does not contain anything in any dataframe for " + f"{', '.join(target_to_data_frame.keys())}." ) return targeted_data diff --git a/vizro-core/tests/unit/vizro/actions/test_actions_utils.py b/vizro-core/tests/unit/vizro/actions/test_actions_utils.py index 56e09f899..348c114a5 100644 --- a/vizro-core/tests/unit/vizro/actions/test_actions_utils.py +++ b/vizro-core/tests/unit/vizro/actions/test_actions_utils.py @@ -1,12 +1,12 @@ import pytest -from vizro.actions._actions_utils import _create_target_arg_mapping, _update_nested_graph_properties +from vizro.actions._actions_utils import _get_target_dot_separated_strings, _update_nested_figure_properties class TestUpdateNestedGraphProperties: - def test_update_nested_graph_properties_single_level(self): + def test_update_nested_figure_properties_single_level(self): graph = {"color": "blue"} - result = _update_nested_graph_properties(graph, "color", "red") + result = _update_nested_figure_properties(graph, "color", "red") expected = {"color": "red"} assert result == expected @@ -22,8 +22,8 @@ def test_update_nested_graph_properties_single_level(self): ), ], ) - def test_update_nested_graph_properties_multiple_levels(self, graph, dot_separated_strings, value, expected): - result = _update_nested_graph_properties(graph, dot_separated_strings, value) + def test_update_nested_figure_properties_multiple_levels(self, graph, dot_separated_strings, value, expected): + result = _update_nested_figure_properties(graph, dot_separated_strings, value) assert result == expected @pytest.mark.parametrize( @@ -50,47 +50,57 @@ def test_update_nested_graph_properties_multiple_levels(self, graph, dot_separat ({}, "color", "red", {"color": "red"}), ], ) - def test_update_nested_graph_properties_add_or_overwrite_keys(self, graph, dot_separated_strings, value, expected): - result = _update_nested_graph_properties(graph, dot_separated_strings, value) + def test_update_nested_figure_properties_add_or_overwrite_keys(self, graph, dot_separated_strings, value, expected): + result = _update_nested_figure_properties(graph, dot_separated_strings, value) assert result == expected - def test_update_nested_graph_properties_invalid_type(self): + def test_update_nested_figure_properties_invalid_type(self): graph = {"color": "blue"} with pytest.raises(TypeError, match="'str' object does not support item assignment"): - _update_nested_graph_properties(graph, "color.value", "42") + _update_nested_figure_properties(graph, "color.value", "42") -class TestCreateTargetArgMapping: - def test_single_string_one_component(self): - input_strings = ["component1.argument1"] - expected = {"component1": ["argument1"]} - result = _create_target_arg_mapping(input_strings) - assert result == expected - - def test_multiple_strings_different_components(self): - input_strings = ["component1.argument1", "component2.argument2", "component1.argument3"] - expected = {"component1": ["argument1", "argument3"], "component2": ["argument2"]} - result = _create_target_arg_mapping(input_strings) - assert result == expected - - def test_multiple_strings_same_component(self): - input_strings = ["component1.argument1", "component1.argument2", "component1.argument3"] - expected = {"component1": ["argument1", "argument2", "argument3"]} - result = _create_target_arg_mapping(input_strings) - assert result == expected - - def test_empty_input_list(self): - input_strings = [] - expected = {} - result = _create_target_arg_mapping(input_strings) - assert result == expected - - def test_strings_without_separator(self): - input_strings = ["component1_argument1", "component2_argument2"] - with pytest.raises(ValueError, match="must contain a '.'"): - _create_target_arg_mapping(input_strings) +class TestFilterDotSeparatedStrings: + @pytest.mark.parametrize( + "dot_separated_strings, expected", + [ + ([], []), + (["component1.argument1", "component1.data_frame.x"], ["data_frame.x"]), + ( + [ + "component1.argument1", + "component1.data_frame.x", + "component1.data_frame.y", + "component2.argument2", + "component2.data_frame.z", + "component1.argument3", + ], + ["data_frame.x", "data_frame.y"], + ), + (["component1.argument1.extra", "component1.data_frame.x"], ["data_frame.x"]), + ], + ) + def test_filter_data_frame_parameters(self, dot_separated_strings, expected): + assert _get_target_dot_separated_strings(dot_separated_strings, "component1", data_frame=True) == expected - def test_strings_with_multiple_separators(self): - input_strings = ["component1.argument1.extra", "component2.argument2.extra"] - expected = {"component1": ["argument1.extra"], "component2": ["argument2.extra"]} - assert _create_target_arg_mapping(input_strings) == expected + @pytest.mark.parametrize( + "dot_separated_strings, expected", + [ + ([], []), + (["component1.argument1", "component1.data_frame.x"], ["argument1"]), + ( + [ + "component1.argument1", + "component1.data_frame.x", + "component1.data_frame.y", + "component2.argument2", + "component2.data_frame.z", + "component1.argument3", + ], + ["argument1", "argument3"], + ), + (["component1.argument1.extra", "component1.data_frame.x"], ["argument1.extra"]), + ], + ) + def test_filter_non_data_frame_parameters(self, dot_separated_strings, expected): + assert _get_target_dot_separated_strings(dot_separated_strings, "component1", data_frame=False) == expected diff --git a/vizro-core/tests/unit/vizro/managers/test_data_manager.py b/vizro-core/tests/unit/vizro/managers/test_data_manager.py index dd6b69195..b8a53ab4b 100644 --- a/vizro-core/tests/unit/vizro/managers/test_data_manager.py +++ b/vizro-core/tests/unit/vizro/managers/test_data_manager.py @@ -13,6 +13,7 @@ from vizro import Vizro from vizro.managers import data_manager +from vizro.managers._data_manager import _DynamicData, _StaticData # Fixture that freezes the time so that tests involving time.sleep can run quickly. Instead of time.sleep, @@ -42,6 +43,10 @@ def make_fixed_data(): return pd.DataFrame([1, 2, 3]) +def make_fixed_data_with_args(label, another_label="x"): + return pd.DataFrame([1, 2, 3]).assign(label=label, another_label=another_label) + + class TestLoad: def test_static(self): data = make_fixed_data() @@ -68,6 +73,100 @@ def test_dynamic_lambda(self): assert loaded_data is not data() +class TestMultiLoad: + def test_static_single_request(self, mocker): + # Single value in multi_name_load_kwargs loads the data. + data_manager["data"] = make_fixed_data() + load_spy = mocker.spy(_StaticData, "load") + loaded_data = data_manager._multi_load([("data", {})]) + assert load_spy.call_count == 1 + assert len(loaded_data) == 1 + assert_frame_equal(loaded_data[0], make_fixed_data()) + + def test_static_multiple_requests(self, mocker): + # Multiple distinct values in multi_name_load_kwargs are loaded separately but repeated ones are not. + data_manager["data_x"] = make_fixed_data() + data_manager["data_y"] = make_fixed_data() + load_spy = mocker.spy(_StaticData, "load") + loaded_data = data_manager._multi_load([("data_x", {}), ("data_y", {}), ("data_x", {})]) + assert load_spy.call_count == 2 # Crucially this is not 3. + assert len(loaded_data) == 3 + assert_frame_equal(loaded_data[0], make_fixed_data()) + assert_frame_equal(loaded_data[1], make_fixed_data()) + assert_frame_equal(loaded_data[2], make_fixed_data()) + + # Behavior of static data and dynamic data with no arguments is the same. + def test_dynamic_single_request_no_args(self, mocker): + # Single value in multi_name_load_kwargs loads the data. + data_manager["data"] = make_fixed_data + load_spy = mocker.spy(_DynamicData, "load") + loaded_data = data_manager._multi_load([("data", {})]) + assert load_spy.call_count == 1 + assert len(loaded_data) == 1 + assert_frame_equal(loaded_data[0], make_fixed_data()) + + # Behavior of static data and dynamic data with no arguments is the same. + def test_dynamic_multiple_requests_no_args(self, mocker): + # Multiple distinct values in multi_name_load_kwargs are loaded separately but repeated ones are not. + data_manager["data_x"] = make_fixed_data + data_manager["data_y"] = make_fixed_data + load_spy = mocker.spy(_DynamicData, "load") + loaded_data = data_manager._multi_load([("data_x", {}), ("data_y", {}), ("data_x", {})]) + assert load_spy.call_count == 2 # Crucially this is not 3. + assert len(loaded_data) == 3 + assert_frame_equal(loaded_data[0], make_fixed_data()) + assert_frame_equal(loaded_data[1], make_fixed_data()) + assert_frame_equal(loaded_data[2], make_fixed_data()) + + # Test various JSON-serialisable types of argument value. + @pytest.mark.parametrize("label", ["y", None, [1, 2, 3], {"a": "b"}]) + def test_dynamic_single_request_with_args(self, label, mocker): + # Single value in multi_name_load_kwargs loads the data. + data_manager["data"] = make_fixed_data_with_args + load_spy = mocker.spy(_DynamicData, "load") + loaded_data = data_manager._multi_load([("data", {"label": label})]) + assert load_spy.call_count == 1 + assert len(loaded_data) == 1 + assert_frame_equal(loaded_data[0], make_fixed_data_with_args(label=label)) + + def test_dynamic_multiple_requests_with_args(self, mocker): + # Multiple distinct values in multi_name_load_kwargs are loaded separately but repeated ones are not. + data_manager["data_x"] = make_fixed_data_with_args + data_manager["data_y"] = make_fixed_data_with_args + load_spy = mocker.spy(_DynamicData, "load") + loaded_data = data_manager._multi_load( + [ + ("data_x", {"label": "x"}), + ("data_x", {"label": "y"}), + ("data_y", {"label": "x"}), + ("data_x", {"label": "x"}), # Repeat of first entry. + ] + ) + assert load_spy.call_count == 3 # Crucially this is not 4. + assert len(loaded_data) == 4 + assert_frame_equal(loaded_data[0], make_fixed_data_with_args(label="x")) + assert_frame_equal(loaded_data[1], make_fixed_data_with_args(label="y")) + assert_frame_equal(loaded_data[2], make_fixed_data_with_args(label="x")) + assert_frame_equal(loaded_data[3], make_fixed_data_with_args(label="x")) + + def test_dynamic_args_order_does_not_matter(self, mocker): + # Multiple distinct values in multi_name_load_kwargs are loaded separately but repeated ones are not. + data_manager["data"] = make_fixed_data_with_args + load_spy = mocker.spy(_DynamicData, "load") + loaded_data = data_manager._multi_load( + [ + ("data", {"label": "x", "another_label": "x"}), + ("data", {"label": "x", "another_label": "y"}), + ("data", {"another_label": "x", "label": "x"}), + ] + ) + assert load_spy.call_count == 2 # Crucially this is not 3. + assert len(loaded_data) == 3 + assert_frame_equal(loaded_data[0], make_fixed_data_with_args(label="x", another_label="x")) + assert_frame_equal(loaded_data[1], make_fixed_data_with_args(label="x", another_label="y")) + assert_frame_equal(loaded_data[2], make_fixed_data_with_args(label="x", another_label="x")) + + class TestInvalid: def test_static_data_does_not_support_timeout(self): data = make_fixed_data()