Skip to content

Commit

Permalink
fix: raise errors for the n_repeats mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Nov 4, 2024
1 parent 10e7209 commit def7938
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 18 deletions.
40 changes: 36 additions & 4 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,14 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
generated_initial_conditions = None

if provided_initial_conditions is not None and generated_initial_conditions is not None:
provided_initial_conditions = provided_initial_conditions.repeat(
opt_inputs.num_restarts, *([1] * (provided_initial_conditions.dim()-1))
)
if ( # Repeat the provided initial conditions to match the number of restarts
provided_initial_conditions.shape[0] == 1
and opt_inputs.num_restarts is not None
and opt_inputs.num_restarts > 1
):
provided_initial_conditions = provided_initial_conditions.repeat(
opt_inputs.num_restarts, *([1] * (provided_initial_conditions.dim()-1))
)
batch_initial_conditions = torch.cat(
[provided_initial_conditions, generated_initial_conditions], dim=-2
) # should this be shuffled?
Expand Down Expand Up @@ -377,7 +382,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
"`batch_initial_conditions`>`raw_samples`, optimization will not "
"be retried with new initial conditions and will proceed with the "
"current solution. Suggested remediation: Try again with different "
"`batch_initial_conditions`, don't provide `batch_initial_conditions, "
"`batch_initial_conditions`, don't provide `batch_initial_conditions`, "
"or increase `raw_samples`.`"
if required_raw_samples is not None and required_raw_samples > 0
else "Optimization failed in `gen_candidates_scipy` with the following "
Expand Down Expand Up @@ -1230,6 +1235,25 @@ def optimize_acqf_discrete_local_search(
- a `q x d`-dim tensor of generated candidates.
- an associated acquisition value.
"""
if (
batch_initial_conditions is not None
and len(batch_initial_conditions.shape) != 3
):
raise ValueError("`batch_initial_conditions` must be of shape `n x 1 x d`.")

if (
raw_samples is not None
and batch_initial_conditions is not None
and (raw_samples - batch_initial_conditions.shape[-2]) > 0
and num_restarts is not None
and num_restarts != batch_initial_conditions.shape[0]
):
raise ValueError(
"If using `batch_initial_conditions` together with `raw_samples`, "
"the first repeat dimension of `batch_initial_conditions` must "
"match `num_restarts`."
)

candidate_list = []
base_X_pending = acq_function.X_pending if q > 1 else None
base_X_avoid = X_avoid
Expand All @@ -1245,12 +1269,20 @@ def optimize_acqf_discrete_local_search(
if i == 0:

if batch_initial_conditions is not None:
# FIXME not sure about given `n x 1 x d` why is only a single input allowed?
provided_X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
provided_X0 = _filter_infeasible(
X=provided_X0, inequality_constraints=inequality_constraints
).unsqueeze(1)
if raw_samples is not None:
required_raw_samples = raw_samples - batch_initial_conditions.shape[-2]

if ( # Repeat the provided initial conditions to match the number of restarts
provided_X0.shape[0] == 1
and num_restarts is not None
and num_restarts > 1
):
provided_X0 = provided_X0.repeat(num_restarts, 1, 1)
else:
required_raw_samples = raw_samples
provided_X0 = None
Expand Down
34 changes: 20 additions & 14 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_optimize_acqf_joint(
cnt += 1
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test generation with provided initial conditions less than raw_samples
# test generation with provided initial conditions larger than raw_samples
candidates, acq_vals = optimize_acqf(
acq_function=mock_acq_function,
bounds=bounds,
Expand All @@ -183,10 +183,9 @@ def test_optimize_acqf_joint(
)
self.assertTrue(torch.equal(candidates, mock_candidates))
self.assertTrue(torch.equal(acq_vals, mock_acq_values))
cnt += 1
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test generation with provided initial conditions greater than raw_samples
# test generation with provided initial conditions less than raw_samples
candidates, acq_vals = optimize_acqf(
acq_function=mock_acq_function,
bounds=bounds,
Expand All @@ -202,6 +201,7 @@ def test_optimize_acqf_joint(
)
self.assertTrue(torch.equal(candidates, mock_candidates))
self.assertTrue(torch.equal(acq_vals, mock_acq_values))
cnt += 1
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test fixed features
Expand Down Expand Up @@ -661,13 +661,9 @@ def test_optimize_acqf_warns_on_opt_failure(self):
message = (
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
"_IN_LNSRCH.')]\nBecause you specified "
"`batch_initial_conditions`>`raw_samples`, optimization will not "
"be retried with new initial conditions and will proceed with the "
"current solution. Suggested remediation: Try again with different "
"`batch_initial_conditions`, don't provide `batch_initial_conditions, "
"or increase `raw_samples`.`"
"`scipy.optimize.minimize` with status 2 and message "
"ABNORMAL_TERMINATION_IN_LNSRCH.')]\nTrying again with a new set "
"of initial conditions."
)
expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
Expand Down Expand Up @@ -711,8 +707,13 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
message = (
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
"_IN_LNSRCH.')]\nTrying again with a new set of initial conditions."
"`scipy.optimize.minimize` with status 2 and message "
"ABNORMAL_TERMINATION_IN_LNSRCH.')]\nBecause you specified "
"`batch_initial_conditions`>`raw_samples`, optimization will not "
"be retried with new initial conditions and will proceed with the "
"current solution. Suggested remediation: Try again with different "
"`batch_initial_conditions`, don't provide `batch_initial_conditions`, "
"or increase `raw_samples`.`"
)
expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
Expand Down Expand Up @@ -777,8 +778,13 @@ def test_optimize_acqf_warns_on_second_opt_failure(self):
message_1 = (
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
"_IN_LNSRCH.')]\nTrying again with a new set of initial conditions."
"`scipy.optimize.minimize` with status 2 and message "
"ABNORMAL_TERMINATION_IN_LNSRCH.')]\nBecause you specified "
"`batch_initial_conditions`>`raw_samples`, optimization will "
"not be retried with new initial conditions and will proceed with "
"the current solution. Suggested remediation: Try again with "
"different `batch_initial_conditions`, don't provide "
"`batch_initial_conditions`, or increase `raw_samples`.`"
)

message_2 = (
Expand Down

0 comments on commit def7938

Please sign in to comment.