@@ -85,6 +85,7 @@ class OptimizeAcqfInputs:
85
85
return_full_tree : bool = False
86
86
retry_on_optimization_warning : bool = True
87
87
ic_gen_kwargs : dict = dataclasses .field (default_factory = dict )
88
+ acqf_sequence : list [AcquisitionFunction ] | None = None
88
89
89
90
@property
90
91
def full_tree (self ) -> bool :
@@ -168,6 +169,14 @@ def __post_init__(self) -> None:
168
169
):
169
170
raise ValueError ("All indices (keys) in `fixed_features` must be >= 0." )
170
171
172
+ if self .acqf_sequence is not None :
173
+ if not self .sequential :
174
+ raise ValueError ("acqf_sequence requires sequential optimization." )
175
+ if len (self .acqf_sequence ) != self .q :
176
+ raise ValueError ("acqf_sequence must have length q." )
177
+ if self .q < 2 :
178
+ raise ValueError ("acqf_sequence requires q > 1." )
179
+
171
180
def get_ic_generator (self ) -> TGenInitialConditions :
172
181
if self .ic_generator is not None :
173
182
return self .ic_generator
@@ -266,26 +275,35 @@ def _optimize_acqf_sequential_q(
266
275
candidate_list , acq_value_list = [], []
267
276
base_X_pending = opt_inputs .acq_function .X_pending
268
277
269
- new_inputs = dataclasses .replace (
270
- opt_inputs ,
271
- q = 1 ,
272
- batch_initial_conditions = None ,
273
- return_best_only = True ,
274
- sequential = False ,
275
- timeout_sec = timeout_sec ,
276
- )
278
+ new_kwargs = {
279
+ "q" : 1 ,
280
+ "batch_initial_conditions" : None ,
281
+ "return_best_only" : True ,
282
+ "sequential" : False ,
283
+ "timeout_sec" : timeout_sec ,
284
+ "acqf_sequence" : None ,
285
+ }
286
+
277
287
for i in range (opt_inputs .q ):
288
+ if opt_inputs .acqf_sequence is not None :
289
+ new_kwargs ["acq_function" ] = opt_inputs .acqf_sequence [i ]
290
+ new_inputs = dataclasses .replace (opt_inputs , ** new_kwargs )
291
+ if len (candidate_list ) > 0 :
292
+ candidates = torch .cat (candidate_list , dim = - 2 )
293
+ new_inputs .acq_function .set_X_pending (
294
+ torch .cat ([base_X_pending , candidates ], dim = - 2 )
295
+ if base_X_pending is not None
296
+ else candidates
297
+ )
278
298
candidate , acq_value = _optimize_acqf_batch (new_inputs )
279
299
280
300
candidate_list .append (candidate )
281
301
acq_value_list .append (acq_value )
282
- candidates = torch .cat (candidate_list , dim = - 2 )
283
- new_inputs .acq_function .set_X_pending (
284
- torch .cat ([base_X_pending , candidates ], dim = - 2 )
285
- if base_X_pending is not None
286
- else candidates
287
- )
302
+
288
303
logger .info (f"Generated sequential candidate { i + 1 } of { opt_inputs .q } " )
304
+ model_name = type (new_inputs .acq_function .model ).__name__
305
+ logger .debug (f"Used model { model_name } for candidate generation." )
306
+ candidates = torch .cat (candidate_list , dim = - 2 )
289
307
opt_inputs .acq_function .set_X_pending (base_X_pending )
290
308
return candidates , torch .stack (acq_value_list )
291
309
@@ -532,6 +550,7 @@ def optimize_acqf(
532
550
return_best_only : bool = True ,
533
551
gen_candidates : TGenCandidates | None = None ,
534
552
sequential : bool = False ,
553
+ acqf_sequence : list [AcquisitionFunction ] | None = None ,
535
554
* ,
536
555
ic_generator : TGenInitialConditions | None = None ,
537
556
timeout_sec : float | None = None ,
@@ -627,6 +646,10 @@ def optimize_acqf(
627
646
inputs. Default: `gen_candidates_scipy`
628
647
sequential: If False, uses joint optimization, otherwise uses sequential
629
648
optimization for optimizing multiple joint candidates (q > 1).
649
+ acqf_sequence: A list of acquisition functions to be optimized sequentially.
650
+ Must be of length q>1, and requires sequential=True. Used for ensembling
651
+ candidates from different acquisition functions. If omitted, use
652
+ `acq_function` to generate all `q` candidates.
630
653
ic_generator: Function for generating initial conditions. Not needed when
631
654
`batch_initial_conditions` are provided. Defaults to
632
655
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
@@ -689,6 +712,7 @@ def optimize_acqf(
689
712
return_full_tree = return_full_tree ,
690
713
retry_on_optimization_warning = retry_on_optimization_warning ,
691
714
ic_gen_kwargs = ic_gen_kwargs ,
715
+ acqf_sequence = acqf_sequence ,
692
716
)
693
717
return _optimize_acqf (opt_inputs = opt_acqf_inputs )
694
718
@@ -707,7 +731,9 @@ def _optimize_acqf(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
707
731
)
708
732
709
733
# Perform sequential optimization via successive conditioning on pending points
710
- if opt_inputs .sequential and opt_inputs .q > 1 :
734
+ if (
735
+ opt_inputs .sequential and opt_inputs .q > 1
736
+ ) or opt_inputs .acqf_sequence is not None :
711
737
return _optimize_acqf_sequential_q (opt_inputs = opt_inputs )
712
738
713
739
# Batch optimization (including the case q=1)
0 commit comments