From 87f0fa202f5222754bde7a289a858224129399b7 Mon Sep 17 00:00:00 2001 From: Antony Milne Date: Fri, 5 Apr 2024 15:20:17 +0100 Subject: [PATCH] Make dataset name the cache key so that preload is not needed --- vizro-core/examples/_dev/app.py | 42 +++++++++------- .../src/vizro/managers/_data_manager.py | 48 ++++++++----------- .../src/vizro/managers/_model_manager.py | 2 + .../models/_components/_components_utils.py | 37 +++++++------- vizro-core/src/vizro/models/types.py | 6 +-- 5 files changed, 66 insertions(+), 69 deletions(-) diff --git a/vizro-core/examples/_dev/app.py b/vizro-core/examples/_dev/app.py index 11f21e685..cc6012700 100644 --- a/vizro-core/examples/_dev/app.py +++ b/vizro-core/examples/_dev/app.py @@ -1,33 +1,39 @@ """Example to show dashboard configuration.""" +from flask_caching import Cache import vizro.models as vm import vizro.plotly.express as px from vizro import Vizro +from vizro.managers import data_manager from vizro.tables import dash_ag_grid -df = px.data.gapminder() +from vizro.managers import data_manager + +df = px.data.iris() + +# Cache of default_expire_data expires every 5 minutes, the default +data_manager.cache = Cache(config={"CACHE_TYPE": "FileSystemCache", "CACHE_DIR": "cache", "CACHE_DEFAULT_TIMEOUT": 20}) +data_manager["default_expire_data"] = lambda: px.data.iris() + +# Set cache of fast_expire_data to expire every 10 seconds +data_manager["fast_expire_data"] = lambda: px.data.iris() +data_manager["fast_expire_data"].timeout = 5 +# Set cache of no_expire_data to never expire +data_manager["no_expire_data"] = lambda: px.data.iris() +data_manager["no_expire_data"].timeout = 0 page = vm.Page( - title="Enhanced AG Grid", + title="Blah", components=[ - vm.AgGrid( - title="Dash AG Grid", - figure=dash_ag_grid( - data_frame=df, - columnDefs=[ - {"field": "country", "floatingFilter": True, "suppressHeaderMenuButton": True}, - {"field": "continent", "floatingFilter": True, "suppressHeaderMenuButton": True}, - {"field": "year"}, - {"field": "lifeExp", "cellDataType": "numeric"}, - {"field": "pop", "cellDataType": "numeric"}, - {"field": "gdpPercap", "cellDataType": "euro"}, - ], - ), - ), + vm.Graph(figure=px.scatter(df, "sepal_width", "sepal_length")), + vm.Graph(figure=px.scatter("default_expire_data", "sepal_width", "sepal_length")), + vm.Graph(figure=px.scatter("fast_expire_data", "sepal_width", "sepal_length")), + vm.Graph(figure=px.scatter("no_expire_data", "sepal_width", "sepal_length")), ], - controls=[vm.Filter(column="continent")], ) dashboard = vm.Dashboard(pages=[page]) +app = Vizro().build(dashboard) +server = app.dash.server if __name__ == "__main__": - Vizro().build(dashboard).run() + app.run() diff --git a/vizro-core/src/vizro/managers/_data_manager.py b/vizro-core/src/vizro/managers/_data_manager.py index 117025360..238f0af1a 100644 --- a/vizro-core/src/vizro/managers/_data_manager.py +++ b/vizro-core/src/vizro/managers/_data_manager.py @@ -1,7 +1,7 @@ """The data manager handles access to all DataFrames used in a Vizro app.""" from __future__ import annotations - +import os import logging from typing import Callable, Dict, Optional, Union @@ -17,13 +17,11 @@ # * new error messages that are raised # * set cache to null in all other tests - -# * rename in code -# * make sure no mention of lazy or eager/active +# TODO: __main__ in this file: remove/move to docs logger = logging.getLogger(__name__) - +logger.setLevel(logging.DEBUG) # Really ComponentID and DataSourceName should be NewType and not just aliases but then for a user's code to type check # correctly they would need to cast all strings to these types. # TODO: remove these type aliases once have moved component to data mapping to models @@ -87,13 +85,16 @@ class _DynamicData: Possibly in future, this will become a public class so you could directly do: >>> data_manager["dynamic_data"] = DynamicData(dynamic_data, timeout=5) + But we'd need to make sure that name is not an argument in __init__ then. At this point we might like to disable the behaviour so that data_manager setitem and getitem handle the same object rather than doing an implicit conversion to _DynamicData. """ - def __init__(self, load_data: pd_DataFrameCallable, /): + def __init__(self, name: str, load_data: pd_DataFrameCallable): self.__load_data: pd_DataFrameCallable = load_data + # name is needed for the cache key and should not be modified by users. + self._name = name self.timeout: Optional[int] = None # We might also want a self.cache_arguments dictionary in future that allows user to customise more than just # timeout, but no rush to do this since other arguments are unlikely to be useful. @@ -109,30 +110,20 @@ def load(self) -> pd.DataFrame: return self.__load_data() def __repr__(self): - """This is just the default repr so behaviour would be the same if we removed the function definition. + """Flask-caching uses repr to form the cache key, so this is very important to set correctly. - The reason for defining this is to have somewhere to put the following warning: caching currently relies on - this returning a string that depends on id(self). This is relied on by flask_caching.utils.function_namespace - and our own memoize decorator. If this method were changed to no longer include some representation of id(self) - then cache keys would be mixed up. + In particular, it must depend on something that uniquely labels the data source and is the same across all + workers: self._name. Using id(self), as in Python's default repr, only works in the case that gunicorn is + running with --preload: without preloading, the id of the same data source in different processes will be + different so the cache will not match up. flask_caching make it possible to set a __cached_id__ attribute to handle this so that repr can be set independently of cache key, but this doesn't seem to be well documented or work well, so it's better to rely on __repr__. - - In future we might like to change the cache so that it works on data source name rather than the place in memory. - This would necessitate a new _DynamicData._name attribute. This would get us closer to getting gunicorn to work - without relying on --preload, although it would not get all the way there: - * model_manager would need to fix a random seed or alternative solution (just like Dash does for its - component ids) - * not clear how to handle the case of unnamed data sources, where the name is currently generated - automatically by the id, since without --preload this would give mismatched names. If use a random number with - fixed seed for this then lose advantage of multiple plots that use the same data source having just one - underlying dataframe in memory. - * would need to make it possible to disable cache at build time so that data would be loaded once at build - time and then again once at runtime, which is generally not what we want """ - return super().__repr__() + # Note that using repr(self.__load_data) is not good since it depends on the id of self.__load_data and so + # would not be consistent across processes. + return f"{self.__class__.__name__}({self._name}, {self.__load_data.__qualname__})" class _StaticData: @@ -148,7 +139,8 @@ class _StaticData: 2. to raise a clear error message if a user tries to set a timeout on the data source """ - def __init__(self, data: pd.DataFrame, /): + def __init__(self, data: pd.DataFrame): + # No need for _name here because static data doesn't need a cache key. self.__data = data def load(self) -> pd.DataFrame: @@ -203,7 +195,7 @@ def __setitem__(self, name: DataSourceName, data: Union[pd.DataFrame, pd_DataFra raise ValueError(f"Data source {name} already exists.") if callable(data): - self.__data[name] = _DynamicData(data) + self.__data[name] = _DynamicData(name, data) elif isinstance(data, pd.DataFrame): self.__data[name] = _StaticData(data) else: @@ -239,7 +231,8 @@ def _get_component_data(self, component_id: ComponentID) -> pd.DataFrame: if component_id not in self.__component_to_data: raise KeyError(f"Component {component_id} does not exist. You need to call add_component first.") name = self.__component_to_data[component_id] - logger.debug("Loading data %s with id %s", name, id(self[name])) + + logger.debug(f"Loading data %s on process %s", name, os.getpid()) return self[name].load() def _clear(self): @@ -252,7 +245,6 @@ def _clear(self): if __name__ == "__main__": - # TODO: remove this/move to docs from functools import partial import vizro.plotly.express as px diff --git a/vizro-core/src/vizro/managers/_model_manager.py b/vizro-core/src/vizro/managers/_model_manager.py index c2b154417..d52784acb 100644 --- a/vizro-core/src/vizro/managers/_model_manager.py +++ b/vizro-core/src/vizro/managers/_model_manager.py @@ -12,6 +12,8 @@ from vizro.models import VizroBaseModel from vizro.models._action._actions_chain import ActionsChain +# As done for Dash components in dash.development.base_component, fixing the random seed is required to make sure that +# the randomly generated model ID for the same model matches up across workers when running gunicorn without --preload. rd = random.Random(0) ModelID = NewType("ModelID", str) diff --git a/vizro-core/src/vizro/models/_components/_components_utils.py b/vizro-core/src/vizro/models/_components/_components_utils.py index 2a5871d94..81cd93fb9 100644 --- a/vizro-core/src/vizro/models/_components/_components_utils.py +++ b/vizro-core/src/vizro/models/_components/_components_utils.py @@ -1,3 +1,5 @@ +import uuid + import logging from functools import partial @@ -26,26 +28,23 @@ def _callable_mode_validator_factory(mode: str): def _process_callable_data_frame(captured_callable, values): data_frame = captured_callable["data_frame"] - # Enable running e.g. px.scatter("iris") from the Python API and specification of "data_frame": "iris" through JSON. - # In these cases, data already exists in the data manager and just needs to be linked to the component. if isinstance(data_frame, str): - data_manager._add_component(values["id"], data_frame) - return captured_callable - - # Standard case for px.scatter(df: pd.DataFrame). - # Extract dataframe from the captured function and put it into the data manager. - dataset_name = str(id(data_frame)) - - logger.debug("Adding data to data manager for Figure with id %s", values["id"]) - # If the dataset already exists in the data manager then it's not a problem, it just means that we don't need - # to duplicate it. Just log the exception for debugging purposes. - try: - data_manager[dataset_name] = data_frame - except ValueError as exc: - logger.debug(exc) - - data_manager._add_component(values["id"], dataset_name) - + # Enable running with DynamicData, e.g. px.scatter("iris") from the Python API and specification of + # "data_frame": "iris" through JSON. In these cases, data already exists in the data manager and just needs to be + # linked to the component. + data_source_name = data_frame + else: + # Standard StaticData case for px.scatter(df: pd.DataFrame). + # Extract dataframe from the captured function and put it into the data manager. + # Unlike with model_manager, it doesn't matter if the random seed is different across workers here. So long as we + # always fetch StaticData data from the data manager my going through the appropriate Figure component, the right + # data name will be fetched. It also doesn't matter that multiple Figures with the same underlying data + # each have their own entry in the data manager, since the underlying pd.DataFrame will still be the same and not + # copied into each one, so no memory is wasted. + logger.debug("Adding data to data manager for Figure with id %s", values["id"]) + data_source_name = str(uuid.uuid4()) + + data_manager._add_component(values["id"], data_source_name) # No need to keep the data in the captured function any more so remove it to save memory. del captured_callable["data_frame"] return captured_callable diff --git a/vizro-core/src/vizro/models/types.py b/vizro-core/src/vizro/models/types.py index a3e809666..518172600 100644 --- a/vizro-core/src/vizro/models/types.py +++ b/vizro-core/src/vizro/models/types.py @@ -282,10 +282,8 @@ def wrapped(*args, **kwargs) -> _DashboardReadyFigure: if isinstance(captured_callable["data_frame"], str): # Enable running e.g. px.scatter("iris") from the Python API. Don't actually run the function - # because it won't get work as there's no data. It's vital we don't fetch data from the data manager - # yet either, because otherwise all lazy data will be loaded before the dashboard is started. - # This case is not relevant for the JSON/YAML API, which is handled separately through validation of - # CapturedCallable. + # because it won't get work as there's no data. This case is not relevant for the JSON/YAML API, + # which is handled separately through validation of CapturedCallable. fig = _DashboardReadyFigure() else: # Standard case for px.scatter(df: pd.DataFrame).