@@ -999,10 +999,12 @@ def sparse_to_dense_constraints(
999999def optimize_posterior_samples (
10001000 paths : GenericDeterministicModel ,
10011001 bounds : Tensor ,
1002- raw_samples : int = 1024 ,
1003- num_restarts : int = 20 ,
1002+ raw_samples : int = 2048 ,
1003+ num_restarts : int = 4 ,
10041004 sample_transform : Callable [[Tensor ], Tensor ] | None = None ,
10051005 return_transformed : bool = False ,
1006+ suggested_points : Tensor | None = None ,
1007+ options : dict | None = None ,
10061008) -> tuple [Tensor , Tensor ]:
10071009 r"""Cheaply maximizes posterior samples by random querying followed by
10081010 gradient-based optimization using SciPy's L-BFGS-B routine.
@@ -1011,19 +1013,27 @@ def optimize_posterior_samples(
10111013 paths: Random Fourier Feature-based sample paths from the GP
10121014 bounds: The bounds on the search space.
10131015 raw_samples: The number of samples with which to query the samples initially.
1016+ Raw samples are cheap to evaluate, so this should ideally be set much higher
1017+ than num_restarts.
10141018 num_restarts: The number of points selected for gradient-based optimization.
1019+ Should be set low relative to the number of raw samples.
10151020 sample_transform: A callable transform of the sample outputs (e.g.
10161021 MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
10171022 negate the objective or otherwise transform the output.
10181023 return_transformed: A boolean indicating whether to return the transformed
10191024 or non-transformed samples.
1025+ suggested_points: Tensor of suggested input locations that are high-valued.
1026+ These are more densely evaluated during the sampling phase of optimization.
1027+ options: Options for generation of initial candidates. Keys read in this
1028+ function: `frac_random`, `sample_around_best_sigma` and `eta`.
10201029
10211030 Returns:
10221031 A two-element tuple containing:
10231032 - X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
10241033 - f_opt: A `num_optima x [batch_size] x m`-dim, optionally
10251034 `num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
10261035 """
1036+ options = {} if options is None else options
10271037
10281038 def path_func (x ) -> Tensor :
10291039 res = paths (x )
@@ -1032,21 +1042,35 @@ def path_func(x) -> Tensor:
10321042
10331043 return res .squeeze (- 1 )
10341044
1035- candidate_set = unnormalize (
1036- SobolEngine (dimension = bounds .shape [1 ], scramble = True ).draw (n = raw_samples ),
1037- bounds = bounds ,
1038- )
10391045 # queries all samples on all candidates - output shape
10401046 # raw_samples * num_optima * num_models
1047+ frac_random = 1 if suggested_points is None else options .get ("frac_random" , 0.9 )
1048+ candidate_set = draw_sobol_samples (
1049+ bounds = bounds , n = round (raw_samples * frac_random ), q = 1
1050+ ).squeeze (- 2 )
1051+ if frac_random < 1 :
1052+ perturbed_suggestions = sample_truncated_normal_perturbations (
1053+ X = suggested_points ,
1054+ n_discrete_points = round (raw_samples * (1 - frac_random )),
1055+ sigma = options .get ("sample_around_best_sigma" , 1e-2 ),
1056+ bounds = bounds ,
1057+ )
1058+ candidate_set = torch .cat ((candidate_set , perturbed_suggestions ))
1059+
10411060 candidate_queries = path_func (candidate_set )
1042- argtop_k = torch .topk (candidate_queries , num_restarts , dim = - 1 ).indices
1043- X_top_k = candidate_set [argtop_k , :]
1061+ idx = boltzmann_sample (
1062+ function_values = candidate_queries .unsqueeze (- 1 ),
1063+ num_samples = num_restarts ,
1064+ eta = options .get ("eta" , 5.0 ),
1065+ replacement = False ,
1066+ )
1067+ ics = candidate_set [idx , :]
10441068
10451069 # to avoid circular import, the import occurs here
10461070 from botorch .generation .gen import gen_candidates_scipy
10471071
10481072 X_top_k , f_top_k = gen_candidates_scipy (
1049- X_top_k ,
1073+ ics ,
10501074 path_func ,
10511075 lower_bounds = bounds [0 ],
10521076 upper_bounds = bounds [1 ],
@@ -1100,8 +1124,9 @@ def boltzmann_sample(
11001124 eta *= temp_decrease
11011125 weights = torch .exp (eta * norm_weights )
11021126
1127+ # squeeze in case of m = 1 (mono-output provided as batch_size x N x 1)
11031128 return batched_multinomial (
1104- weights = weights , num_samples = num_samples , replacement = replacement
1129+ weights = weights . squeeze ( - 1 ) , num_samples = num_samples , replacement = replacement
11051130 )
11061131
11071132
0 commit comments