Skip to content

Commit

Permalink
Handle errors loading last GR state with reduced_state (#3011)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3011

If we're using reduced state and getting an error because the last GR doesn't use reduced state, try with reduced state

Reviewed By: lena-kashtelyan

Differential Revision: D65350559
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Nov 1, 2024
1 parent d3dd862 commit 1adda28
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 26 deletions.
61 changes: 37 additions & 24 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ax.core.runner import Runner
from ax.core.search_space import HierarchicalSearchSpace, RobustSearchSpace, SearchSpace
from ax.core.trial import Trial
from ax.exceptions.storage import SQADecodeError
from ax.exceptions.storage import JSONDecodeError, SQADecodeError
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.storage.json_store.decoder import object_from_json
from ax.storage.sqa_store.db import session_scope
Expand Down Expand Up @@ -814,33 +814,46 @@ def generation_strategy_from_sqa(
if experiment is not None
else False
)
if reduced_state and gs_sqa.generator_runs:

gs._generator_runs = [
self.generator_run_from_sqa(
generator_run_sqa=gr,
reduced_state=reduced_state,
immutable_search_space_and_opt_config=immutable_ss_and_oc,
)
for gr in gs_sqa.generator_runs[:-1]
]
# This check is necessary to prevent an index error
# on `gs_sqa.generator_runs[-1]`
if gs_sqa.generator_runs:
# Only fully load the last of the generator runs, load the rest with
# reduced state.
gs._generator_runs = [
self.generator_run_from_sqa(
generator_run_sqa=gr,
reduced_state=True,
immutable_search_space_and_opt_config=immutable_ss_and_oc,
# reduced state. This is necessary for stateful models. The only
# stateful models available in open source ax is currently SOBOL.
try:
gs._generator_runs.append(
self.generator_run_from_sqa(
generator_run_sqa=gs_sqa.generator_runs[-1],
reduced_state=False,
immutable_search_space_and_opt_config=immutable_ss_and_oc,
)
)
for gr in gs_sqa.generator_runs[:-1]
]
gs._generator_runs.append(
self.generator_run_from_sqa(
generator_run_sqa=gs_sqa.generator_runs[-1],
reduced_state=False,
immutable_search_space_and_opt_config=immutable_ss_and_oc,
except JSONDecodeError:
if not reduced_state:
raise

logger.exception(
"Failed to decode the last generator run because of the following "
"error. Loading with reduced state:"
)
)
else:
gs._generator_runs = [
self.generator_run_from_sqa(
generator_run_sqa=gr,
reduced_state=False,
immutable_search_space_and_opt_config=immutable_ss_and_oc,
# If the last generator run is not fully loadable, load it with
# reduced state.
gs._generator_runs.append(
self.generator_run_from_sqa(
generator_run_sqa=gs_sqa.generator_runs[-1],
reduced_state=True,
immutable_search_space_and_opt_config=immutable_ss_and_oc,
)
)
for gr in gs_sqa.generator_runs
]
gs._experiment = experiment

if len(gs._generator_runs) > 0:
Expand Down
34 changes: 32 additions & 2 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from decimal import Decimal
from enum import Enum, unique
from logging import Logger
from typing import Any
from typing import Any, Callable, TypeVar
from unittest import mock
from unittest.mock import MagicMock, Mock, patch

Expand All @@ -32,7 +32,7 @@
from ax.core.runner import Runner
from ax.core.types import ComparisonOp
from ax.exceptions.core import ObjectNotFoundError
from ax.exceptions.storage import SQADecodeError, SQAEncodeError
from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError
from ax.metrics.branin import BraninMetric
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.registry import Models
Expand Down Expand Up @@ -133,6 +133,7 @@
logger: Logger = get_logger(__name__)

GET_GS_SQA_IMM_FUNC = _get_generation_strategy_sqa_immutable_opt_config_and_search_space
T = TypeVar("T")


@unique
Expand Down Expand Up @@ -564,6 +565,35 @@ def test_ExperimentSaveAndLoadGRWithOptConfig(self) -> None:
)
self.assertEqual(loaded_experiment.trials[1], exp.trials[1])

def test_load_gr_with_non_decodable_metadata_and_reduced_state(self) -> None:
def spy(original_method: Callable[..., T]) -> Callable[..., T]:
def wrapper(*args: Any, **kwargs: Any) -> T:
# Check if a specific argument is set to a certain value
if "reduced_state" in kwargs and not kwargs["reduced_state"]:
raise JSONDecodeError("Can't decode gen_metadata")
return original_method(*args, **kwargs)

return wrapper

gs = get_generation_strategy(
with_experiment=True, with_callable_model_kwarg=False
)
gs.gen(gs.experiment)
gs.gen(gs.experiment)

save_experiment(gs.experiment)
save_generation_strategy(gs)

with self.assertLogs("ax", level=logging.ERROR):
with patch.object(
Decoder, "generator_run_from_sqa", spy(Decoder.generator_run_from_sqa)
):
load_generation_strategy_by_id(
gs_id=none_throws(gs.db_id),
experiment=gs.experiment,
reduced_state=True,
)

def test_MTExperimentSaveAndLoad(self) -> None:
experiment = get_multi_type_experiment(add_trials=True)
save_experiment(experiment)
Expand Down

0 comments on commit 1adda28

Please sign in to comment.