diff --git a/ax/plot/pareto_utils.py b/ax/plot/pareto_utils.py index 00ccc8fe8c8..a6009705f07 100644 --- a/ax/plot/pareto_utils.py +++ b/ax/plot/pareto_utils.py @@ -579,7 +579,7 @@ def _build_new_optimization_config( def infer_reference_point_from_experiment( - experiment: Experiment, + experiment: Experiment, data: Data ) -> List[ObjectiveThreshold]: """This functions is a wrapper around ``infer_reference_point`` to find the nadir point from the pareto front of an experiment. Aside from converting experiment @@ -600,7 +600,8 @@ def infer_reference_point_from_experiment( # Reading experiment data. mb_reference = get_tensor_converter_model( - experiment=experiment, data=experiment.fetch_data() + experiment=experiment, + data=data, ) obs_feats, obs_data, _ = _get_modelbridge_training_data(modelbridge=mb_reference) diff --git a/ax/plot/tests/test_pareto_utils.py b/ax/plot/tests/test_pareto_utils.py index c3cd0a959fa..505185170db 100644 --- a/ax/plot/tests/test_pareto_utils.py +++ b/ax/plot/tests/test_pareto_utils.py @@ -248,7 +248,10 @@ def test_infer_reference_point_from_experiment(self) -> None: scalarized=False, constrained=False, ) - inferred_reference_point = infer_reference_point_from_experiment(experiment) + data = experiment.fetch_data() + inferred_reference_point = infer_reference_point_from_experiment( + experiment, data=data + ) # The nadir point for this experiment is [-0.5, 0.5]. The function actually # deducts 0.1*Y_range from each of the objectives. Since the range for each # of the objectives is +/-1.5, the inferred reference point would @@ -265,7 +268,7 @@ def test_infer_reference_point_from_experiment(self) -> None: return_value=([], [], [], []), ): with self.assertRaisesRegex(RuntimeError, "No frontier observations found"): - infer_reference_point_from_experiment(experiment) + infer_reference_point_from_experiment(experiment, data=data) def test_constrained_infer_reference_point_from_experiment(self) -> None: experiments = [] @@ -290,14 +293,15 @@ def test_constrained_infer_reference_point_from_experiment(self) -> None: for experiment in experiments: # special case logs a warning message. + data = experiment.fetch_data() if experiment.optimization_config.outcome_constraints[0].bound == 1000.0: with self.assertLogs(logger, "WARNING"): inferred_reference_point = infer_reference_point_from_experiment( - experiment + experiment, data=data ) else: inferred_reference_point = infer_reference_point_from_experiment( - experiment + experiment, data=data ) # The nadir point for this experiment is [-0.5, 0.5]. The function actually # deducts 0.1*Y_range from each of the objectives. Since the range for each @@ -377,7 +381,9 @@ def test_infer_reference_point_from_experiment_shuffled_metrics(self) -> None: obj_t_shuffled, ), ): - inferred_reference_point = infer_reference_point_from_experiment(experiment) + inferred_reference_point = infer_reference_point_from_experiment( + experiment, data=experiment.fetch_data() + ) self.assertEqual(inferred_reference_point[0].op, ComparisonOp.LEQ) self.assertEqual(inferred_reference_point[0].bound, -0.35) diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 5e7113f256e..3fcfb8639b2 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -450,10 +450,16 @@ def completion_criterion(self) -> Tuple[bool, str]: and len(self.experiment.trials_by_status[TrialStatus.COMPLETED]) >= gss.min_trials ): - # We infer the nadir reference point to be used by the GSS. - self.__inferred_reference_point = infer_reference_point_from_experiment( - self.experiment - ) + # only infer reference point if there is data on the experiment. + data = self.experiment.fetch_data() + if not data.df.empty: + # We infer the nadir reference point to be used by the GSS. + self.__inferred_reference_point = ( + infer_reference_point_from_experiment( + self.experiment, + data=data, + ) + ) stop_optimization, global_stopping_msg = gss.should_stop_optimization( experiment=self.experiment, diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index fa4fe4bc110..55d29966520 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -820,6 +820,40 @@ def test_inferring_reference_point(self) -> None: scheduler.run_n_trials(max_trials=10) mock_infer_rp.assert_called_once() + def test_inferring_reference_point_no_data(self) -> None: + init_test_engine_and_session_factory(force_init=True) + experiment = get_branin_experiment_with_multi_objective() + experiment.runner = self.runner + gs = self._get_generation_strategy_strategy_for_test( + experiment=experiment, + generation_strategy=self.sobol_GS_no_parallelism, + ) + + scheduler = Scheduler( + experiment=experiment, + generation_strategy=gs, + options=SchedulerOptions( + # Stops the optimization after 5 trials. + global_stopping_strategy=DummyGlobalStoppingStrategy( + min_trials=0, + trial_to_stop=5, + ), + ), + db_settings=self.db_settings, + ) + empty_data = Data( + df=pd.DataFrame( + columns=["metric_name", "arm_name", "trial_index", "mean", "sem"] + ) + ) + with patch( + "ax.service.scheduler.infer_reference_point_from_experiment" + ) as mock_infer_rp, patch.object( + scheduler.experiment, "fetch_data", return_value=empty_data + ): + scheduler.run_n_trials(max_trials=1) + mock_infer_rp.assert_not_called() + def test_global_stopping(self) -> None: gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment,