diff --git a/e2e_tests/test_server.py b/e2e_tests/test_server.py index 03083c61..b3d7ed49 100644 --- a/e2e_tests/test_server.py +++ b/e2e_tests/test_server.py @@ -8,8 +8,6 @@ from optuna_dashboard import wsgi import pytest -from .utils import clear_inmemory_cache - def get_free_port() -> int: tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -30,7 +28,7 @@ def make_test_server( thread.start() def stop_server() -> None: - clear_inmemory_cache() + app._inmemory_cache.clear() httpd.shutdown() httpd.server_close() thread.join() diff --git a/e2e_tests/utils.py b/e2e_tests/utils.py index ecd7d2cc..a7b22e5f 100644 --- a/e2e_tests/utils.py +++ b/e2e_tests/utils.py @@ -1,19 +1,6 @@ -from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache -from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache_lock -from optuna_dashboard._storage import trials_cache -from optuna_dashboard._storage import trials_cache_lock -from optuna_dashboard._storage import trials_last_fetched_at from playwright.sync_api import Page -def clear_inmemory_cache() -> None: - with trials_cache_lock: - trials_cache.clear() - trials_last_fetched_at.clear() - with cached_extra_study_property_cache_lock: - cached_extra_study_property_cache.clear() - - def count_components(page: Page, component_name: str): component_count = page.evaluate( f"""() => {{ diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index e1c040ac..f7dbc17c 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -27,9 +27,11 @@ from . import _note as note from ._bottle_util import BottleViewReturn from ._bottle_util import json_api_view -from ._cached_extra_study_property import get_cached_extra_study_property from ._custom_plot_data import get_plotly_graph_objects from ._importance import get_param_importance_from_trials_cache +from ._inmemory_cache import get_cached_extra_study_property +from ._inmemory_cache import get_trials +from ._inmemory_cache import InMemoryCache from ._pareto_front import get_pareto_front_trials from ._preference_setting import _register_preference_feedback_component from ._preferential_history import NewHistory @@ -43,7 +45,6 @@ from ._storage import create_new_study from ._storage import get_studies from ._storage import get_study -from ._storage import get_trials from ._storage_url import get_storage from .artifact._backend import delete_all_artifacts from .artifact._backend import register_artifact_route @@ -80,6 +81,7 @@ def create_app( debug: bool = False, ) -> Bottle: app = Bottle() + app._inmemory_cache = InMemoryCache() @app.hook("before_request") def remove_trailing_slashes_hook() -> None: @@ -214,7 +216,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: if study is None: response.status = 404 # Not found return {"reason": f"study_id={study_id} is not found"} - trials = get_trials(storage, study_id) + trials = get_trials(app._inmemory_cache, storage, study_id) system_attrs = getattr(study, "system_attrs", {}) is_preferential = system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False) @@ -235,7 +237,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: union, union_user_attrs, has_intermediate_values, - ) = get_cached_extra_study_property(study_id, trials) + ) = get_cached_extra_study_property(app._inmemory_cache, study_id, trials) plotly_graph_objects = get_plotly_graph_objects(system_attrs) skipped_trial_ids = get_skipped_trial_ids(system_attrs) @@ -261,10 +263,12 @@ def get_param_importances(study_id: int) -> dict[str, Any]: response.status = 404 # Study is not found return {"reason": f"study_id={study_id} is not found"} - trials = get_trials(storage, study_id) + trials = get_trials(app._inmemory_cache, storage, study_id) try: importances = [ - get_param_importance_from_trials_cache(storage, study_id, objective_id, trials) + get_param_importance_from_trials_cache( + storage, study_id, objective_id, trials, app._inmemory_cache + ) for objective_id in range(n_directions) ] return {"param_importances": importances} diff --git a/optuna_dashboard/_cached_extra_study_property.py b/optuna_dashboard/_cached_extra_study_property.py index 24e22372..989aff02 100644 --- a/optuna_dashboard/_cached_extra_study_property.py +++ b/optuna_dashboard/_cached_extra_study_property.py @@ -1,111 +1,111 @@ -from __future__ import annotations - -import copy -import numbers -import threading -from typing import List -from typing import Optional -from typing import Set -from typing import Tuple -from typing import TYPE_CHECKING - -from optuna.distributions import BaseDistribution -from optuna.trial import FrozenTrial -from optuna.trial import TrialState - - -# In-memory cache -cached_extra_study_property_cache_lock = threading.Lock() -cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {} - - -if TYPE_CHECKING: - SearchSpaceSetT = Set[Tuple[str, BaseDistribution]] - SearchSpaceListT = List[Tuple[str, BaseDistribution]] - - -def get_cached_extra_study_property( - study_id: int, trials: list[FrozenTrial] -) -> tuple[SearchSpaceListT, SearchSpaceListT, list[tuple[str, bool]], bool]: - with cached_extra_study_property_cache_lock: - cached_extra_study_property = cached_extra_study_property_cache.get(study_id, None) - if cached_extra_study_property is None: - cached_extra_study_property = _CachedExtraStudyProperty() - cached_extra_study_property.update(trials) - cached_extra_study_property_cache[study_id] = cached_extra_study_property - return ( - cached_extra_study_property.intersection_search_space, - cached_extra_study_property.union_search_space, - cached_extra_study_property.union_user_attrs, - cached_extra_study_property.has_intermediate_values, - ) - - -class _CachedExtraStudyProperty: - def __init__(self) -> None: - self._cursor: int = -1 - self._intersection_search_space: Optional[SearchSpaceSetT] = None - self._union_search_space: SearchSpaceSetT = set() - self._union_user_attrs: dict[str, bool] = {} # attr_name: is_sortable (= is_number) - self.has_intermediate_values: bool = False - - @property - def intersection_search_space(self) -> SearchSpaceListT: - if self._intersection_search_space is None: - return [] - intersection = list(self._intersection_search_space) - intersection.sort(key=lambda x: x[0]) - return intersection - - @property - def union_search_space(self) -> SearchSpaceListT: - union = list(self._union_search_space) - union.sort(key=lambda x: x[0]) - return union - - @property - def union_user_attrs(self) -> list[tuple[str, bool]]: - union = [(name, is_sortable) for name, is_sortable in self._union_user_attrs.items()] - sorted(union, key=lambda x: x[0]) - return union - - def update(self, trials: list[FrozenTrial]) -> None: - next_cursor = self._cursor - for trial in reversed(trials): - if self._cursor > trial.number: - break - - if not trial.state.is_finished(): - next_cursor = trial.number - - self._update_user_attrs(trial) - if trial.state != TrialState.FAIL: - self._update_intermediate_values(trial) - self._update_search_space(trial) - - self._cursor = next_cursor - - def _update_user_attrs(self, trial: FrozenTrial) -> None: - current_user_attrs = { - k: not isinstance(v, bool) and isinstance(v, numbers.Real) - for k, v in trial.user_attrs.items() - } - for attr_name, current_is_sortable in current_user_attrs.items(): - is_sortable = self._union_user_attrs.get(attr_name) - if is_sortable is None: - self._union_user_attrs[attr_name] = current_is_sortable - elif is_sortable and not current_is_sortable: - self._union_user_attrs[attr_name] = False - - def _update_intermediate_values(self, trial: FrozenTrial) -> None: - if not self.has_intermediate_values and len(trial.intermediate_values) > 0: - self.has_intermediate_values = True - - def _update_search_space(self, trial: FrozenTrial) -> None: - current = set([(n, d) for n, d in trial.distributions.items()]) - self._union_search_space = self._union_search_space.union(current) - - if self._intersection_search_space is None: - self._intersection_search_space = copy.copy(current) - else: - self._intersection_search_space = self._intersection_search_space.intersection(current) +# from __future__ import annotations + +# import copy +# import numbers +# import threading +# from typing import List +# from typing import Optional +# from typing import Set +# from typing import Tuple +# from typing import TYPE_CHECKING + +# from optuna.distributions import BaseDistribution +# from optuna.trial import FrozenTrial +# from optuna.trial import TrialState + + +# # In-memory cache +# cached_extra_study_property_cache_lock = threading.Lock() +# cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {} + + +# if TYPE_CHECKING: +# SearchSpaceSetT = Set[Tuple[str, BaseDistribution]] +# SearchSpaceListT = List[Tuple[str, BaseDistribution]] + + +# def get_cached_extra_study_property( +# study_id: int, trials: list[FrozenTrial] +# ) -> tuple[SearchSpaceListT, SearchSpaceListT, list[tuple[str, bool]], bool]: +# with cached_extra_study_property_cache_lock: +# cached_extra_study_property = cached_extra_study_property_cache.get(study_id, None) +# if cached_extra_study_property is None: +# cached_extra_study_property = _CachedExtraStudyProperty() +# cached_extra_study_property.update(trials) +# cached_extra_study_property_cache[study_id] = cached_extra_study_property +# return ( +# cached_extra_study_property.intersection_search_space, +# cached_extra_study_property.union_search_space, +# cached_extra_study_property.union_user_attrs, +# cached_extra_study_property.has_intermediate_values, +# ) + + +# class _CachedExtraStudyProperty: +# def __init__(self) -> None: +# self._cursor: int = -1 +# self._intersection_search_space: Optional[SearchSpaceSetT] = None +# self._union_search_space: SearchSpaceSetT = set() +# self._union_user_attrs: dict[str, bool] = {} # attr_name: is_sortable (= is_number) +# self.has_intermediate_values: bool = False + +# @property +# def intersection_search_space(self) -> SearchSpaceListT: +# if self._intersection_search_space is None: +# return [] +# intersection = list(self._intersection_search_space) +# intersection.sort(key=lambda x: x[0]) +# return intersection + +# @property +# def union_search_space(self) -> SearchSpaceListT: +# union = list(self._union_search_space) +# union.sort(key=lambda x: x[0]) +# return union + +# @property +# def union_user_attrs(self) -> list[tuple[str, bool]]: +# union = [(name, is_sortable) for name, is_sortable in self._union_user_attrs.items()] +# sorted(union, key=lambda x: x[0]) +# return union + +# def update(self, trials: list[FrozenTrial]) -> None: +# next_cursor = self._cursor +# for trial in reversed(trials): +# if self._cursor > trial.number: +# break + +# if not trial.state.is_finished(): +# next_cursor = trial.number + +# self._update_user_attrs(trial) +# if trial.state != TrialState.FAIL: +# self._update_intermediate_values(trial) +# self._update_search_space(trial) + +# self._cursor = next_cursor + +# def _update_user_attrs(self, trial: FrozenTrial) -> None: +# current_user_attrs = { +# k: not isinstance(v, bool) and isinstance(v, numbers.Real) +# for k, v in trial.user_attrs.items() +# } +# for attr_name, current_is_sortable in current_user_attrs.items(): +# is_sortable = self._union_user_attrs.get(attr_name) +# if is_sortable is None: +# self._union_user_attrs[attr_name] = current_is_sortable +# elif is_sortable and not current_is_sortable: +# self._union_user_attrs[attr_name] = False + +# def _update_intermediate_values(self, trial: FrozenTrial) -> None: +# if not self.has_intermediate_values and len(trial.intermediate_values) > 0: +# self.has_intermediate_values = True + +# def _update_search_space(self, trial: FrozenTrial) -> None: +# current = set([(n, d) for n, d in trial.distributions.items()]) +# self._union_search_space = self._union_search_space.union(current) + +# if self._intersection_search_space is None: +# self._intersection_search_space = copy.copy(current) +# else: +# self._intersection_search_space = self._intersection_search_space.intersection(current) diff --git a/optuna_dashboard/_importance.py b/optuna_dashboard/_importance.py index b7189412..da1be75d 100644 --- a/optuna_dashboard/_importance.py +++ b/optuna_dashboard/_importance.py @@ -11,7 +11,8 @@ from optuna.study import Study from optuna.trial import FrozenTrial from optuna.trial import TrialState -from optuna_dashboard._cached_extra_study_property import get_cached_extra_study_property +from optuna_dashboard._inmemory_cache import get_cached_extra_study_property +from optuna_dashboard._inmemory_cache import InMemoryCache _logger = logging.getLogger(__name__) @@ -97,7 +98,11 @@ def _get_param_importances( def get_param_importance_from_trials_cache( - storage: BaseStorage, study_id: int, objective_id: int, trials: list[FrozenTrial] + storage: BaseStorage, + study_id: int, + objective_id: int, + trials: list[FrozenTrial], + inmemory_cache: InMemoryCache, ) -> list[ImportanceType]: completed_trials = [t for t in trials if t.state == TrialState.COMPLETE] n_completed_trials = len(completed_trials) @@ -118,7 +123,9 @@ def get_param_importance_from_trials_cache( except RuntimeError: # RuntimeError("Encountered zero total variance in all trees.") may be raised # when all objective values are same. - _, union_search_space, _, _ = get_cached_extra_study_property(study_id, trials) + _, union_search_space, _, _ = get_cached_extra_study_property( + inmemory_cache, study_id, trials + ) importance_value = 1 / len(union_search_space) importance = { param_name: importance_value for param_name, distribution in union_search_space diff --git a/optuna_dashboard/_inmemory_cache.py b/optuna_dashboard/_inmemory_cache.py new file mode 100644 index 00000000..367d46c3 --- /dev/null +++ b/optuna_dashboard/_inmemory_cache.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import copy +from datetime import datetime +from datetime import timedelta +import numbers +import threading +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING + +from optuna.distributions import BaseDistribution +from optuna.storages import BaseStorage +from optuna.trial import FrozenTrial +from optuna.trial import TrialState + + +if TYPE_CHECKING: + SearchSpaceSetT = Set[Tuple[str, BaseDistribution]] + SearchSpaceListT = List[Tuple[str, BaseDistribution]] + + +def get_cached_extra_study_property( + in_memory_cache: InMemoryCache, study_id: int, trials: list[FrozenTrial] +) -> tuple[SearchSpaceListT, SearchSpaceListT, list[tuple[str, bool]], bool]: + with in_memory_cache._cached_extra_study_property_cache_lock: + cached_extra_study_property = in_memory_cache._cached_extra_study_property_cache.get( + study_id, None + ) + if cached_extra_study_property is None: + cached_extra_study_property = _CachedExtraStudyProperty() + cached_extra_study_property.update(trials) + in_memory_cache._cached_extra_study_property_cache[study_id] = cached_extra_study_property + return ( + cached_extra_study_property.intersection_search_space, + cached_extra_study_property.union_search_space, + cached_extra_study_property.union_user_attrs, + cached_extra_study_property.has_intermediate_values, + ) + + +def get_trials( + in_memory_cache: InMemoryCache, storage: BaseStorage, study_id: int +) -> list[FrozenTrial]: + with in_memory_cache._trials_cache_lock: + trials = in_memory_cache._trials_cache.get(study_id, None) + + # Not a big fan of the heuristic, but I can't think of anything better. + if trials is None or len(trials) < 100: + ttl_seconds = 2 + elif len(trials) < 500: + ttl_seconds = 5 + else: + ttl_seconds = 10 + + last_fetched_at = in_memory_cache._trials_last_fetched_at.get(study_id, None) + if ( + trials is not None + and last_fetched_at is not None + and datetime.now() - last_fetched_at < timedelta(seconds=ttl_seconds) + ): + return trials + trials = storage.get_all_trials(study_id, deepcopy=False) + + with in_memory_cache._trials_cache_lock: + in_memory_cache._trials_last_fetched_at[study_id] = datetime.now() + in_memory_cache._trials_cache[study_id] = trials + return trials + + +class InMemoryCache: + def __init__(self): + self._cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {} + self._cached_extra_study_property_cache_lock = threading.Lock() + self._trials_cache: dict[int, list[FrozenTrial]] = {} + self._trials_cache_lock = threading.Lock() + self._trials_last_fetched_at: dict[int, datetime] = {} + + def clear(self): + self._cached_extra_study_property_cache.clear() + self._trials_cache.clear() + + +class _CachedExtraStudyProperty: + def __init__(self) -> None: + self._cursor: int = -1 + self._intersection_search_space: Optional[SearchSpaceSetT] = None + self._union_search_space: SearchSpaceSetT = set() + self._union_user_attrs: dict[str, bool] = {} # attr_name: is_sortable (= is_number) + self.has_intermediate_values: bool = False + + @property + def intersection_search_space(self) -> SearchSpaceListT: + if self._intersection_search_space is None: + return [] + intersection = list(self._intersection_search_space) + intersection.sort(key=lambda x: x[0]) + return intersection + + @property + def union_search_space(self) -> SearchSpaceListT: + union = list(self._union_search_space) + union.sort(key=lambda x: x[0]) + return union + + @property + def union_user_attrs(self) -> list[tuple[str, bool]]: + union = [(name, is_sortable) for name, is_sortable in self._union_user_attrs.items()] + sorted(union, key=lambda x: x[0]) + return union + + def update(self, trials: list[FrozenTrial]) -> None: + next_cursor = self._cursor + for trial in reversed(trials): + if self._cursor > trial.number: + break + + if not trial.state.is_finished(): + next_cursor = trial.number + + self._update_user_attrs(trial) + if trial.state != TrialState.FAIL: + self._update_intermediate_values(trial) + self._update_search_space(trial) + + self._cursor = next_cursor + + def _update_user_attrs(self, trial: FrozenTrial) -> None: + current_user_attrs = { + k: not isinstance(v, bool) and isinstance(v, numbers.Real) + for k, v in trial.user_attrs.items() + } + for attr_name, current_is_sortable in current_user_attrs.items(): + is_sortable = self._union_user_attrs.get(attr_name) + if is_sortable is None: + self._union_user_attrs[attr_name] = current_is_sortable + elif is_sortable and not current_is_sortable: + self._union_user_attrs[attr_name] = False + + def _update_intermediate_values(self, trial: FrozenTrial) -> None: + if not self.has_intermediate_values and len(trial.intermediate_values) > 0: + self.has_intermediate_values = True + + def _update_search_space(self, trial: FrozenTrial) -> None: + current = set([(n, d) for n, d in trial.distributions.items()]) + self._union_search_space = self._union_search_space.union(current) + + if self._intersection_search_space is None: + self._intersection_search_space = copy.copy(current) + else: + self._intersection_search_space = self._intersection_search_space.intersection(current) diff --git a/optuna_dashboard/_storage.py b/optuna_dashboard/_storage.py index 8fe10e0a..fda7d04c 100644 --- a/optuna_dashboard/_storage.py +++ b/optuna_dashboard/_storage.py @@ -17,33 +17,6 @@ trials_last_fetched_at: dict[int, datetime] = {} -def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]: - with trials_cache_lock: - trials = trials_cache.get(study_id, None) - - # Not a big fan of the heuristic, but I can't think of anything better. - if trials is None or len(trials) < 100: - ttl_seconds = 2 - elif len(trials) < 500: - ttl_seconds = 5 - else: - ttl_seconds = 10 - - last_fetched_at = trials_last_fetched_at.get(study_id, None) - if ( - trials is not None - and last_fetched_at is not None - and datetime.now() - last_fetched_at < timedelta(seconds=ttl_seconds) - ): - return trials - trials = storage.get_all_trials(study_id, deepcopy=False) - - with trials_cache_lock: - trials_last_fetched_at[study_id] = datetime.now() - trials_cache[study_id] = trials - return trials - - def get_studies(storage: BaseStorage) -> list[FrozenStudy]: frozen_studies = storage.get_all_studies() if isinstance(storage, RDBStorage): diff --git a/python_tests/test_cached_extra_study_property.py b/python_tests/test_cached_extra_study_property.py index 381cbcca..8e316719 100644 --- a/python_tests/test_cached_extra_study_property.py +++ b/python_tests/test_cached_extra_study_property.py @@ -11,7 +11,7 @@ from optuna.distributions import FloatDistribution from optuna.exceptions import ExperimentalWarning from optuna.trial import TrialState -from optuna_dashboard._cached_extra_study_property import _CachedExtraStudyProperty +from optuna_dashboard._inmemory_cache import _CachedExtraStudyProperty class _CachedExtraStudyPropertySearchSpaceTestCase(TestCase): diff --git a/python_tests/wsgi_client.py b/python_tests/wsgi_client.py index d0abef9a..0141dde4 100644 --- a/python_tests/wsgi_client.py +++ b/python_tests/wsgi_client.py @@ -4,9 +4,6 @@ import typing from bottle import Bottle -from optuna_dashboard._storage import trials_cache -from optuna_dashboard._storage import trials_cache_lock -from optuna_dashboard._storage import trials_last_fetched_at if typing.TYPE_CHECKING: @@ -14,12 +11,6 @@ from _typeshed.wsgi import WSGIEnvironment -def clear_inmemory_cache() -> None: - with trials_cache_lock: - trials_cache.clear() - trials_last_fetched_at.clear() - - def create_wsgi_env( path: str, method: str, @@ -77,7 +68,7 @@ def start_response( queries = queries or {} env = create_wsgi_env(path, method, content_type, bytes_body, queries, headers) - clear_inmemory_cache() + app._inmemory_cache.clear() response_body = b"" iterable_body = app(env, start_response) for b in iterable_body: