diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index d342a0e1c1d..aef9ac30e93 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -46,7 +46,7 @@ from ax.core.runner import Runner from ax.core.search_space import HierarchicalSearchSpace, RobustSearchSpace, SearchSpace from ax.core.trial import Trial -from ax.exceptions.storage import SQADecodeError +from ax.exceptions.storage import JSONDecodeError, SQADecodeError from ax.modelbridge.generation_strategy import GenerationStrategy from ax.storage.json_store.decoder import object_from_json from ax.storage.sqa_store.db import session_scope @@ -814,33 +814,46 @@ def generation_strategy_from_sqa( if experiment is not None else False ) - if reduced_state and gs_sqa.generator_runs: + + gs._generator_runs = [ + self.generator_run_from_sqa( + generator_run_sqa=gr, + reduced_state=reduced_state, + immutable_search_space_and_opt_config=immutable_ss_and_oc, + ) + for gr in gs_sqa.generator_runs[:-1] + ] + # This check is necessary to prevent an index error + # on `gs_sqa.generator_runs[-1]` + if gs_sqa.generator_runs: # Only fully load the last of the generator runs, load the rest with - # reduced state. - gs._generator_runs = [ - self.generator_run_from_sqa( - generator_run_sqa=gr, - reduced_state=True, - immutable_search_space_and_opt_config=immutable_ss_and_oc, + # reduced state. This is necessary for stateful models. The only + # stateful models available in open source ax is currently SOBOL. + try: + gs._generator_runs.append( + self.generator_run_from_sqa( + generator_run_sqa=gs_sqa.generator_runs[-1], + reduced_state=False, + immutable_search_space_and_opt_config=immutable_ss_and_oc, + ) ) - for gr in gs_sqa.generator_runs[:-1] - ] - gs._generator_runs.append( - self.generator_run_from_sqa( - generator_run_sqa=gs_sqa.generator_runs[-1], - reduced_state=False, - immutable_search_space_and_opt_config=immutable_ss_and_oc, + except JSONDecodeError: + if not reduced_state: + raise + + logger.exception( + "Failed to decode the last generator run because of the following " + "error. Loading with reduced state:" ) - ) - else: - gs._generator_runs = [ - self.generator_run_from_sqa( - generator_run_sqa=gr, - reduced_state=False, - immutable_search_space_and_opt_config=immutable_ss_and_oc, + # If the last generator run is not fully loadable, load it with + # reduced state. + gs._generator_runs.append( + self.generator_run_from_sqa( + generator_run_sqa=gs_sqa.generator_runs[-1], + reduced_state=True, + immutable_search_space_and_opt_config=immutable_ss_and_oc, + ) ) - for gr in gs_sqa.generator_runs - ] gs._experiment = experiment if len(gs._generator_runs) > 0: diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 75566a76330..b25005ed929 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -11,7 +11,7 @@ from decimal import Decimal from enum import Enum, unique from logging import Logger -from typing import Any +from typing import Any, Callable, TypeVar from unittest import mock from unittest.mock import MagicMock, Mock, patch @@ -32,7 +32,7 @@ from ax.core.runner import Runner from ax.core.types import ComparisonOp from ax.exceptions.core import ObjectNotFoundError -from ax.exceptions.storage import SQADecodeError, SQAEncodeError +from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError from ax.metrics.branin import BraninMetric from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.registry import Models @@ -77,6 +77,7 @@ SQAAbandonedArm, SQAArm, SQAExperiment, + SQAGenerationStrategy, SQAGeneratorRun, SQAMetric, SQAParameter, @@ -133,6 +134,7 @@ logger: Logger = get_logger(__name__) GET_GS_SQA_IMM_FUNC = _get_generation_strategy_sqa_immutable_opt_config_and_search_space +T = TypeVar("T") @unique @@ -564,6 +566,35 @@ def test_ExperimentSaveAndLoadGRWithOptConfig(self) -> None: ) self.assertEqual(loaded_experiment.trials[1], exp.trials[1]) + def test_load_gr_with_non_decodable_metadata_and_reduced_state(self) -> None: + def spy(original_method: Callable[..., T]) -> Callable[..., T]: + def wrapper(*args: Any, **kwargs: Any) -> T: + # Check if a specific argument is set to a certain value + if "reduced_state" in kwargs and not kwargs["reduced_state"]: + raise JSONDecodeError("Can't decode gen_metadata") + return original_method(*args, **kwargs) + + return wrapper + + gs = get_generation_strategy( + with_experiment=True, with_callable_model_kwarg=False + ) + gs.gen(gs.experiment) + gs.gen(gs.experiment) + + save_experiment(gs.experiment) + save_generation_strategy(gs) + + with self.assertLogs("ax", level=logging.ERROR): + with patch.object( + Decoder, "generator_run_from_sqa", spy(Decoder.generator_run_from_sqa) + ): + load_generation_strategy_by_id( + gs_id=none_throws(gs.db_id), + experiment=gs.experiment, + reduced_state=True, + ) + def test_MTExperimentSaveAndLoad(self) -> None: experiment = get_multi_type_experiment(add_trials=True) save_experiment(experiment)