Skip to content

Commit 683716d

Browse files
blethamfacebook-github-bot
authored andcommitted
optimize sequence of acquisition functions (#2931)
Summary: Pull Request resolved: #2931 Enables sequential q-batch optimization using a sequence of acquisition functions rather than the same acquisition function for each point in the batch. Right now just for when using optimize_acqf, thus focused on continuous search spaces. Differential Revision: D78560867
1 parent 9342a9e commit 683716d

File tree

2 files changed

+108
-15
lines changed

2 files changed

+108
-15
lines changed

botorch/optim/optimize.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class OptimizeAcqfInputs:
8585
return_full_tree: bool = False
8686
retry_on_optimization_warning: bool = True
8787
ic_gen_kwargs: dict = dataclasses.field(default_factory=dict)
88+
acqf_sequence: list[AcquisitionFunction] | None = None
8889

8990
@property
9091
def full_tree(self) -> bool:
@@ -168,6 +169,14 @@ def __post_init__(self) -> None:
168169
):
169170
raise ValueError("All indices (keys) in `fixed_features` must be >= 0.")
170171

172+
if self.acqf_sequence is not None:
173+
if not self.sequential:
174+
raise ValueError("acqf_sequence requires sequential optimization.")
175+
if len(self.acqf_sequence) != self.q:
176+
raise ValueError("acqf_sequence must have length q.")
177+
if self.q < 2:
178+
raise ValueError("acqf_sequence requires q > 1.")
179+
171180
def get_ic_generator(self) -> TGenInitialConditions:
172181
if self.ic_generator is not None:
173182
return self.ic_generator
@@ -266,26 +275,35 @@ def _optimize_acqf_sequential_q(
266275
candidate_list, acq_value_list = [], []
267276
base_X_pending = opt_inputs.acq_function.X_pending
268277

269-
new_inputs = dataclasses.replace(
270-
opt_inputs,
271-
q=1,
272-
batch_initial_conditions=None,
273-
return_best_only=True,
274-
sequential=False,
275-
timeout_sec=timeout_sec,
276-
)
278+
new_kwargs = {
279+
"q": 1,
280+
"batch_initial_conditions": None,
281+
"return_best_only": True,
282+
"sequential": False,
283+
"timeout_sec": timeout_sec,
284+
"acqf_sequence": None,
285+
}
286+
277287
for i in range(opt_inputs.q):
288+
if opt_inputs.acqf_sequence is not None:
289+
new_kwargs["acq_function"] = opt_inputs.acqf_sequence[i]
290+
new_inputs = dataclasses.replace(opt_inputs, **new_kwargs)
291+
if len(candidate_list) > 0:
292+
candidates = torch.cat(candidate_list, dim=-2)
293+
new_inputs.acq_function.set_X_pending(
294+
torch.cat([base_X_pending, candidates], dim=-2)
295+
if base_X_pending is not None
296+
else candidates
297+
)
278298
candidate, acq_value = _optimize_acqf_batch(new_inputs)
279299

280300
candidate_list.append(candidate)
281301
acq_value_list.append(acq_value)
282-
candidates = torch.cat(candidate_list, dim=-2)
283-
new_inputs.acq_function.set_X_pending(
284-
torch.cat([base_X_pending, candidates], dim=-2)
285-
if base_X_pending is not None
286-
else candidates
287-
)
302+
288303
logger.info(f"Generated sequential candidate {i + 1} of {opt_inputs.q}")
304+
model_name = type(new_inputs.acq_function.model).__name__
305+
logger.debug(f"Used model {model_name} for candidate generation.")
306+
candidates = torch.cat(candidate_list, dim=-2)
289307
opt_inputs.acq_function.set_X_pending(base_X_pending)
290308
return candidates, torch.stack(acq_value_list)
291309

@@ -532,6 +550,7 @@ def optimize_acqf(
532550
return_best_only: bool = True,
533551
gen_candidates: TGenCandidates | None = None,
534552
sequential: bool = False,
553+
acqf_sequence: list[AcquisitionFunction] | None = None,
535554
*,
536555
ic_generator: TGenInitialConditions | None = None,
537556
timeout_sec: float | None = None,
@@ -627,6 +646,10 @@ def optimize_acqf(
627646
inputs. Default: `gen_candidates_scipy`
628647
sequential: If False, uses joint optimization, otherwise uses sequential
629648
optimization for optimizing multiple joint candidates (q > 1).
649+
acqf_sequence: A list of acquisition functions to be optimized sequentially.
650+
Must be of length q>1, and requires sequential=True. Used for ensembling
651+
candidates from different acquisition functions. If omitted, use
652+
`acq_function` to generate all `q` candidates.
630653
ic_generator: Function for generating initial conditions. Not needed when
631654
`batch_initial_conditions` are provided. Defaults to
632655
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
@@ -689,6 +712,7 @@ def optimize_acqf(
689712
return_full_tree=return_full_tree,
690713
retry_on_optimization_warning=retry_on_optimization_warning,
691714
ic_gen_kwargs=ic_gen_kwargs,
715+
acqf_sequence=acqf_sequence,
692716
)
693717
return _optimize_acqf(opt_inputs=opt_acqf_inputs)
694718

@@ -707,7 +731,9 @@ def _optimize_acqf(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
707731
)
708732

709733
# Perform sequential optimization via successive conditioning on pending points
710-
if opt_inputs.sequential and opt_inputs.q > 1:
734+
if (
735+
opt_inputs.sequential and opt_inputs.q > 1
736+
) or opt_inputs.acqf_sequence is not None:
711737
return _optimize_acqf_sequential_q(opt_inputs=opt_inputs)
712738

713739
# Batch optimization (including the case q=1)

test/optim/test_optimize.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,73 @@ def test_optimize_acqf_sequential(
449449
sequential=True,
450450
)
451451

452+
@mock.patch(
453+
"botorch.optim.optimize.gen_candidates_scipy", wraps=gen_candidates_scipy
454+
)
455+
def test_optimize_acqf_sequence(
456+
self,
457+
mock_gen_candidates_scipy,
458+
):
459+
acqf_sequence = [MockAcquisitionFunction() for _ in range(3)]
460+
bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
461+
# Validation
462+
with self.assertRaisesRegex(
463+
ValueError,
464+
"acqf_sequence requires sequential optimization",
465+
):
466+
optimize_acqf(
467+
acq_function=mock.MagicMock(),
468+
bounds=bounds,
469+
q=3,
470+
num_restarts=2,
471+
raw_samples=10,
472+
sequential=False,
473+
acqf_sequence=acqf_sequence,
474+
)
475+
with self.assertRaisesRegex(
476+
ValueError,
477+
"acqf_sequence must have length q",
478+
):
479+
optimize_acqf(
480+
acq_function=mock.MagicMock(),
481+
bounds=bounds,
482+
q=2,
483+
num_restarts=2,
484+
raw_samples=10,
485+
sequential=True,
486+
acqf_sequence=acqf_sequence,
487+
)
488+
with self.assertRaisesRegex(
489+
ValueError,
490+
"acqf_sequence requires q > 1",
491+
):
492+
optimize_acqf(
493+
acq_function=mock.MagicMock(),
494+
bounds=bounds,
495+
q=1,
496+
num_restarts=2,
497+
raw_samples=10,
498+
sequential=True,
499+
acqf_sequence=acqf_sequence[:1],
500+
)
501+
# Test that uses sequence of acquisitions
502+
acq_function = mock.MagicMock()
503+
acq_function.X_pending = None
504+
_ = optimize_acqf(
505+
acq_function=acq_function,
506+
bounds=bounds,
507+
q=3,
508+
num_restarts=2,
509+
raw_samples=10,
510+
sequential=True,
511+
acqf_sequence=acqf_sequence,
512+
)
513+
self.assertEqual(mock_gen_candidates_scipy.call_count, 3)
514+
self.assertIsNone(acqf_sequence[0].X_pending)
515+
for i in range(1, 3):
516+
self.assertEqual(len(acqf_sequence[i].X_pending), i)
517+
acq_function.assert_not_called()
518+
452519
@mock.patch(
453520
"botorch.generation.gen.minimize_with_timeout",
454521
wraps=minimize_with_timeout,

0 commit comments

Comments
 (0)