feat: allow both batch_initial_conditions and random sampling together.
CompRhys committed Nov 4, 2024
1 parent 24f659c commit 10e7209
Showing 3 changed files with 134 additions and 24 deletions.
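
The change lets `optimize_acqf` combine user-supplied `batch_initial_conditions` with randomly generated ones instead of treating the two as mutually exclusive: any `raw_samples` not already covered by the provided starting points are still drawn and concatenated on. A minimal sketch of that bookkeeping in plain torch, with made-up sizes (nothing below is copied from the diff):

import torch

# Made-up sizes, chosen only to illustrate the bookkeeping.
num_restarts, q, d = 4, 1, 6
raw_samples = 32

provided = torch.rand(num_restarts, q, d)  # user-supplied batch_initial_conditions
# Only the raw samples not already covered by the provided points are generated.
required_raw_samples = raw_samples - provided.shape[-2]

if required_raw_samples > 0:
    generated = torch.rand(num_restarts, q, d)  # stand-in for the initial-condition generator
    combined = torch.cat([provided, generated], dim=-2)
else:
    combined = provided

print(combined.shape)  # torch.Size([4, 2, 6]) for the sizes above

The diff additionally repeats the provided block across restarts before concatenating, and asks in a comment whether the result should be shuffled; the sketch above leaves both details out.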
2 changes: 1 addition & 1 deletion botorch/acquisition/input_constructors.py
@@ -1779,7 +1779,7 @@ def optimize_objective(
bounds=free_feature_bounds,
q=q,
num_restarts=optimizer_options.get("num_restarts", 60),
raw_samples=optimizer_options.get("raw_samples", 1024),
raw_samples=optimizer_options.get("raw_samples", 1024), # NOTE potential behaviour change
options={
"batch_limit": optimizer_options.get("batch_limit", 8),
"maxiter": optimizer_options.get("maxiter", 200),
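
The only change in this file is the `# NOTE potential behaviour change` marker: the `raw_samples` default of 1024 is untouched, but with the new `optimize_acqf` behaviour below a non-None `raw_samples` now supplements any provided initial conditions downstream instead of being ignored when they are present, which is presumably what the note flags. A tiny illustration of how that default interacts with caller-supplied `optimizer_options` (plain dict semantics, not part of the diff):

optimizer_options = {"num_restarts": 60}
print(optimizer_options.get("raw_samples", 1024))  # 1024: generated points will supplement provided ones

optimizer_options = {"num_restarts": 60, "raw_samples": None}
print(optimizer_options.get("raw_samples", 1024))  # None: explicitly opting out keeps only the provided points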
110 changes: 94 additions & 16 deletions botorch/optim/optimize.py
@@ -114,6 +114,19 @@ def __post_init__(self) -> None:
f"shape is {batch_initial_conditions_shape}."
)

if (
self.raw_samples is not None
and (self.raw_samples - batch_initial_conditions_shape[-2]) > 0
and len(batch_initial_conditions_shape) == 3
and self.num_restarts is not None
and self.num_restarts != batch_initial_conditions_shape[0]
):
raise ValueError(
"If using `batch_initial_conditions` together with `raw_samples`, "
"the first repeat dimension of `batch_initial_conditions` must "
"match `num_restarts`."
)

elif self.ic_generator is None:
if self.nonlinear_inequality_constraints is not None:
raise RuntimeError(
@@ -253,22 +266,44 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor

initial_conditions_provided = opt_inputs.batch_initial_conditions is not None

required_raw_samples = opt_inputs.raw_samples
if initial_conditions_provided:
batch_initial_conditions = opt_inputs.batch_initial_conditions
provided_initial_conditions = opt_inputs.batch_initial_conditions
if opt_inputs.raw_samples is not None:
required_raw_samples -= provided_initial_conditions.shape[-2]
else:
provided_initial_conditions = None

if required_raw_samples is not None and required_raw_samples > 0:
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
batch_initial_conditions = opt_inputs.get_ic_generator()(
generated_initial_conditions = opt_inputs.get_ic_generator()(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
num_restarts=opt_inputs.num_restarts,
raw_samples=opt_inputs.raw_samples,
raw_samples=required_raw_samples,
fixed_features=opt_inputs.fixed_features,
options=options,
inequality_constraints=opt_inputs.inequality_constraints,
equality_constraints=opt_inputs.equality_constraints,
**opt_inputs.ic_gen_kwargs,
)
else:
generated_initial_conditions = None

if provided_initial_conditions is not None and generated_initial_conditions is not None:
provided_initial_conditions = provided_initial_conditions.repeat(
opt_inputs.num_restarts, *([1] * (provided_initial_conditions.dim()-1))
)
batch_initial_conditions = torch.cat(
[provided_initial_conditions, generated_initial_conditions], dim=-2
) # should this be shuffled?
elif provided_initial_conditions is not None:
batch_initial_conditions = provided_initial_conditions
elif generated_initial_conditions is not None:
batch_initial_conditions = generated_initial_conditions
else:
raise ValueError("Either `batch_initial_conditions` or `raw_samples` must be set.")

batch_limit: int = options.get(
"batch_limit",
@@ -339,31 +374,39 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
first_warn_msg = (
"Optimization failed in `gen_candidates_scipy` with the following "
f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
"`batch_initial_conditions`, optimization will not be retried with "
"new initial conditions and will proceed with the current solution."
" Suggested remediation: Try again with different "
"`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
if initial_conditions_provided
"`batch_initial_conditions`>`raw_samples`, optimization will not "
"be retried with new initial conditions and will proceed with the "
"current solution. Suggested remediation: Try again with different "
"`batch_initial_conditions`, don't provide `batch_initial_conditions, "
"or increase `raw_samples`.`"
if required_raw_samples is not None and required_raw_samples > 0
else "Optimization failed in `gen_candidates_scipy` with the following "
f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
"set of initial conditions."
)
warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)

if not initial_conditions_provided:
batch_initial_conditions = opt_inputs.get_ic_generator()(
if required_raw_samples is not None and required_raw_samples > 0:
generated_initial_conditions = opt_inputs.get_ic_generator()(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
num_restarts=opt_inputs.num_restarts,
raw_samples=opt_inputs.raw_samples,
raw_samples=required_raw_samples,
fixed_features=opt_inputs.fixed_features,
options=options,
inequality_constraints=opt_inputs.inequality_constraints,
equality_constraints=opt_inputs.equality_constraints,
**opt_inputs.ic_gen_kwargs,
)

if provided_initial_conditions is not None:
batch_initial_conditions = torch.cat(
[provided_initial_conditions, generated_initial_conditions], dim=-2
) # should this be shuffled?
else:
batch_initial_conditions = generated_initial_conditions

batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()

optimization_warning_raised = any(
@@ -1199,11 +1242,46 @@ def optimize_acqf_discrete_local_search(
inequality_constraints = inequality_constraints or []
for i in range(q):
# generate some starting points
if i == 0 and batch_initial_conditions is not None:
X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
X0 = _filter_infeasible(
X=X0, inequality_constraints=inequality_constraints
).unsqueeze(1)
if i == 0:

if batch_initial_conditions is not None:
provided_X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
provided_X0 = _filter_infeasible(
X=provided_X0, inequality_constraints=inequality_constraints
).unsqueeze(1)
required_raw_samples = raw_samples
if raw_samples is not None:
required_raw_samples = raw_samples - batch_initial_conditions.shape[-2]
else:
required_raw_samples = raw_samples
provided_X0 = None

generated_X0 = None
if required_raw_samples is not None and required_raw_samples > 0:
X_init = _gen_batch_initial_conditions_local_search(
discrete_choices=discrete_choices,
raw_samples=required_raw_samples,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=num_restarts,
)
# pick the best starting points
with torch.no_grad():
acqvals_init = _split_batch_eval_acqf(
acq_function=acq_function,
X=X_init.unsqueeze(1),
max_batch_size=max_batch_size,
).unsqueeze(-1)
generated_X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices]

if provided_X0 is not None and generated_X0 is not None:
provided_X0 = provided_X0.repeat(num_restarts, *([1] * (provided_X0.ndim - 1)))
X0 = torch.cat([provided_X0, generated_X0], dim=-2)
elif provided_X0 is not None:
X0 = provided_X0
elif generated_X0 is not None:
X0 = generated_X0
else:
raise ValueError("Either `batch_initial_conditions` or `raw_samples` must be set.")

else:
X_init = _gen_batch_initial_conditions_local_search(
discrete_choices=discrete_choices,
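
Restating the shape check added to `__post_init__` near the top of this file: when `batch_initial_conditions` is three-dimensional, leaves part of `raw_samples` still to be generated, and `num_restarts` is set, its leading dimension must equal `num_restarts`. A standalone sketch of that rule for illustration only; the helper name `check_ic_shape` and the example shapes are made up, not part of the commit:

import torch

def check_ic_shape(ics: torch.Tensor, raw_samples: int | None, num_restarts: int | None) -> None:
    # Mirrors the validation added in OptimizeAcqfInputs.__post_init__ above.
    shape = ics.shape
    if (
        raw_samples is not None
        and (raw_samples - shape[-2]) > 0
        and len(shape) == 3
        and num_restarts is not None
        and num_restarts != shape[0]
    ):
        raise ValueError(
            "If using `batch_initial_conditions` together with `raw_samples`, "
            "the first dimension of `batch_initial_conditions` must match `num_restarts`."
        )

check_ic_shape(torch.zeros(4, 1, 6), raw_samples=32, num_restarts=4)    # passes: leading dim == num_restarts
check_ic_shape(torch.zeros(4, 1, 6), raw_samples=None, num_restarts=2)  # passes: no raw samples requested
try:
    check_ic_shape(torch.zeros(3, 1, 6), raw_samples=32, num_restarts=4)
except ValueError as err:
    print(err)  # leading dim 3 does not match num_restarts=4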
46 changes: 39 additions & 7 deletions test/optim/test_optimize.py
@@ -167,7 +167,26 @@ def test_optimize_acqf_joint(
cnt += 1
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test generation with provided initial conditions
# test generation when fewer initial conditions are provided than raw_samples
candidates, acq_vals = optimize_acqf(
acq_function=mock_acq_function,
bounds=bounds,
q=q,
num_restarts=num_restarts,
raw_samples=3,
options=options,
return_best_only=False,
batch_initial_conditions=torch.zeros(
num_restarts, q, 3, device=self.device, dtype=dtype
),
gen_candidates=mock_gen_candidates,
)
self.assertTrue(torch.equal(candidates, mock_candidates))
self.assertTrue(torch.equal(acq_vals, mock_acq_values))
cnt += 1
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test generation when more initial conditions are provided than raw_samples
candidates, acq_vals = optimize_acqf(
acq_function=mock_acq_function,
bounds=bounds,
@@ -543,7 +562,15 @@ def test_optimize_acqf_batch_limit(self) -> None:
gen_candidates=gen_candidates,
batch_initial_conditions=ics,
)
expected_shape = (num_restarts,) if ics is None else (ics.shape[0],)
expected_shape = (
(num_restarts,)
if ics is None
else (
(ics.shape[0],)
if ics.shape[0] > raw_samples
else (ics.shape[0]*num_restarts,)
)
)
self.assertEqual(acq_value_list.shape, expected_shape)

def test_optimize_acqf_runs_given_batch_initial_conditions(self):
@@ -635,11 +662,12 @@ def test_optimize_acqf_warns_on_opt_failure(self):
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
"_IN_LNSRCH.')]\nBecause you specified `batch_initial_conditions`, "
"optimization will not be retried with new initial conditions and will "
"proceed with the current solution. Suggested remediation: Try again with "
"different `batch_initial_conditions`, or don't provide "
"`batch_initial_conditions.`"
"_IN_LNSRCH.')]\nBecause you specified "
"`batch_initial_conditions`>`raw_samples`, optimization will not "
"be retried with new initial conditions and will proceed with the "
"current solution. Suggested remediation: Try again with different "
"`batch_initial_conditions`, don't provide `batch_initial_conditions, "
"or increase `raw_samples`.`"
)
expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
@@ -1841,3 +1869,7 @@ def my_gen():
)
ic_generator = opt_inputs.get_ic_generator()
self.assertIs(ic_generator, my_gen)

if __name__ == "__main__":
import pytest
pytest.main([__file__])
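
The reworked `expected_shape` in `test_optimize_acqf_batch_limit` above encodes the new behaviour: one acquisition value per restart when no initial conditions are given, one per provided condition when the provided conditions already exceed `raw_samples`, and provided conditions tiled once per restart otherwise. A small worked restatement with made-up numbers (the values the test actually uses sit outside the shown hunk):

def expected_shape(ics_dim0, num_restarts, raw_samples):
    # Hypothetical restatement of the test's expected_shape expression.
    if ics_dim0 is None:
        return (num_restarts,)
    return (ics_dim0,) if ics_dim0 > raw_samples else (ics_dim0 * num_restarts,)

print(expected_shape(None, num_restarts=2, raw_samples=10))  # (2,)
print(expected_shape(20, num_restarts=2, raw_samples=10))    # (20,): provided ICs already exceed raw_samples
print(expected_shape(3, num_restarts=2, raw_samples=10))     # (6,): provided ICs are repeated per restart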
