@@ -468,6 +468,90 @@ def gen_batch_initial_conditions(
     return batch_initial_conditions


+def gen_optimal_input_initial_conditions(
+    acq_function: AcquisitionFunction,
+    bounds: Tensor,
+    q: int,
+    num_restarts: int,
+    raw_samples: int,
+    fixed_features: dict[int, float] | None = None,
+    options: dict[str, bool | float | int] | None = None,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+) -> Tensor:
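+    r"""Generate initial conditions biased toward the acquisition function's
+    sampled optima.
+
+    A `frac_random` fraction of the `raw_samples` raw points is drawn from the
+    feasible region; the remainder are Gaussian perturbations of the
+    acquisition function's `optimal_inputs` (e.g. the posterior optima carried
+    by entropy-search acquisition functions). The `num_restarts` initial
+    conditions are then drawn from a Boltzmann (soft-max) distribution over
+    the raw points' acquisition values.
+
+    Example (an illustrative sketch; any AcquisitionFunction that exposes an
+    `optimal_inputs` attribute, such as `qJointEntropySearch`, applies):
+        >>> qJES = qJointEntropySearch(model, optimal_inputs, optimal_outputs)
+        >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
+        >>> Xinit = gen_optimal_input_initial_conditions(
+        >>>     qJES, bounds, q=3, num_restarts=10, raw_samples=512,
+        >>>     options={"frac_random": 0.25},
+        >>> )
+    """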
+    device = bounds.device
+    if not hasattr(acq_function, "optimal_inputs"):
+        raise AttributeError(
+            "gen_optimal_input_initial_conditions can only be used with "
+            "an AcquisitionFunction that has an optimal_inputs attribute."
+        )
+    options = options or {}
+    frac_random: float = options.get("frac_random", 0.0)
+    if not 0 <= frac_random <= 1:
+        raise ValueError(
+            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
+        )
+
+    batch_limit = options.get("batch_limit")
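+    # Flatten the (possibly batched) sampled optima into a `num_optima x d`
+    # tensor of suggested locations to perturb below.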
+    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
+    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
+    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
+    num_random = round(raw_samples * frac_random)
+    if num_random > 0:
+        X_rnd = sample_q_batches_from_polytope(
+            n=num_random,
+            q=q,
+            bounds=bounds,
+            n_burnin=options.get("n_burnin", 10000),
+            n_thinning=options.get("n_thinning", 32),
+            equality_constraints=equality_constraints,
+            inequality_constraints=inequality_constraints,
+        )
+        # Accumulate candidates on the CPU; they are moved to `device` only
+        # for evaluation below.
+        X = torch.cat((X, X_rnd.cpu()))
+
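+    # The remaining raw points are Gaussian perturbations of the suggested
+    # optima (`sample_points_around_best` with `best_X=suggestions`).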
+    if num_random < raw_samples:
+        X_perturbed = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * (raw_samples - num_random),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+            best_X=suggestions,
+        )
+        X_perturbed = X_perturbed.view(
+            raw_samples - num_random, q, bounds.shape[-1]
+        ).cpu()
+        X = torch.cat((X, X_perturbed))
+
+    if options.get("sample_around_best", False):
+        X_best = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * raw_samples,
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
+        X = torch.cat((X, X_best))
+
+    with torch.no_grad():
+        if batch_limit is None:
+            batch_limit = X.shape[0]
+        # Evaluate the acquisition function on `X` in `batch_limit`-sized
+        # chunks to bound peak memory use.
+        acq_vals = torch.cat(
+            [
+                acq_function(x_.to(device=device)).cpu()
+                for x_ in X.split(split_size=batch_limit, dim=0)
+            ],
+            dim=0,
+        )
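+    # Boltzmann-sample `num_restarts` of the raw points, weighting points by
+    # their acquisition values (temperature controlled by `eta`).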
+    idx = boltzmann_sample(
+        function_values=acq_vals,
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+    )
+    # Return the sampled points as the initial conditions.
+    return X[idx]
+
+
 def gen_one_shot_kg_initial_conditions(
     acq_function: qKnowledgeGradient,
     bounds: Tensor,
@@ -602,59 +686,59 @@ def gen_one_shot_hvkg_initial_conditions(
 ) -> Tensor | None:
     r"""Generate a batch of smart initializations for qHypervolumeKnowledgeGradient.

-    This function generates initial conditions for optimizing one-shot HVKG using
-    the hypervolume maximizing set (of fixed size) under the posterior mean.
-    Intutively, the hypervolume maximizing set of the fantasized posterior mean
-    will often be close to a hypervolume maximizing set under the current posterior
-    mean. This function uses that fact to generate the initial conditions
-    for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
-    options) of the restarts are generated by learning the hypervolume maximizing sets
-    under the current posterior mean, where each hypervolume maximizing set is
-    obtained from maximizing the hypervolume from a different starting point. Given
-    a hypervolume maximizing set, the `q` candidate points are selected using to the
-    standard initialization strategy in `gen_batch_initial_conditions`, with the fixed
-    hypervolume maximizing set. The remaining `frac_random` restarts fantasy points
-    as well as all `q` candidate points are chosen according to the standard
-    initialization strategy in `gen_batch_initial_conditions`.
-
-    Args:
-        acq_function: The qKnowledgeGradient instance to be optimized.
-        bounds: A `2 x d` tensor of lower and upper bounds for each column of
-            task features.
-        q: The number of candidates to consider.
-        num_restarts: The number of starting points for multistart acquisition
-            function optimization.
-        raw_samples: The number of raw samples to consider in the initialization
-            heuristic.
-        fixed_features: A map `{feature_index: value}` for features that
-            should be fixed to a particular value during generation.
-        options: Options for initial condition generation. These contain all
-            settings for the standard heuristic initialization from
-            `gen_batch_initial_conditions`. In addition, they contain
-            `frac_random` (the fraction of fully random fantasy points),
-            `num_inner_restarts` and `raw_inner_samples` (the number of random
-            restarts and raw samples for solving the posterior objective
-            maximization problem, respectively) and `eta` (temperature parameter
-            for sampling heuristic from posterior objective maximizers).
-        inequality constraints: A list of tuples (indices, coefficients, rhs),
-            with each tuple encoding an inequality constraint of the form
-            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
-        equality constraints: A list of tuples (indices, coefficients, rhs),
-            with each tuple encoding an inequality constraint of the form
-            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
-
-    Returns:
-        A `num_restarts x q' x d` tensor that can be used as initial conditions
-        for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
-        of points (candidate points plus fantasy points).
-
-    Example:
-        >>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
-        >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
-        >>> Xinit = gen_one_shot_hvkg_initial_conditions(
-        >>>     qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
-        >>>     options={"frac_random": 0.25},
-        >>> )
+    This function generates initial conditions for optimizing one-shot HVKG using
+    the hypervolume maximizing set (of fixed size) under the posterior mean.
+    Intuitively, the hypervolume maximizing set of the fantasized posterior mean
+    will often be close to a hypervolume maximizing set under the current posterior
+    mean. This function uses that fact to generate the initial conditions
+    for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
+    options) of the restarts are generated by learning the hypervolume maximizing
+    sets under the current posterior mean, where each hypervolume maximizing set
+    is obtained by maximizing the hypervolume from a different starting point.
+    Given a hypervolume maximizing set, the `q` candidate points are selected
+    according to the standard initialization strategy in
+    `gen_batch_initial_conditions`, with the hypervolume maximizing set held
+    fixed. For the remaining `frac_random` fraction of restarts, the fantasy
+    points as well as all `q` candidate points are chosen according to the
+    standard initialization strategy in `gen_batch_initial_conditions`.
+
+    Args:
+        acq_function: The qHypervolumeKnowledgeGradient instance to be optimized.
+        bounds: A `2 x d` tensor of lower and upper bounds for each column of
+            task features.
+        q: The number of candidates to consider.
+        num_restarts: The number of starting points for multistart acquisition
+            function optimization.
+        raw_samples: The number of raw samples to consider in the initialization
+            heuristic.
+        fixed_features: A map `{feature_index: value}` for features that
+            should be fixed to a particular value during generation.
+        options: Options for initial condition generation. These contain all
+            settings for the standard heuristic initialization from
+            `gen_batch_initial_conditions`. In addition, they contain
+            `frac_random` (the fraction of fully random fantasy points),
+            `num_inner_restarts` and `raw_inner_samples` (the number of random
+            restarts and raw samples for solving the posterior objective
+            maximization problem, respectively) and `eta` (temperature parameter
+            for sampling heuristic from posterior objective maximizers).
+        inequality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an inequality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
+        equality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an equality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
+
+    Returns:
+        A `num_restarts x q' x d` tensor that can be used as initial conditions
+        for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
+        of points (candidate points plus fantasy points).
+
+    Example:
+        >>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
+        >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
+        >>> Xinit = gen_one_shot_hvkg_initial_conditions(
+        >>>     qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
+        >>>     options={"frac_random": 0.25},
+        >>> )
     """
     from botorch.optim.optimize import optimize_acqf

@@ -1136,6 +1220,7 @@ def sample_points_around_best(
     best_pct: float = 5.0,
     subset_sigma: float = 1e-1,
     prob_perturb: float | None = None,
+    best_X: Tensor | None = None,
 ) -> Tensor | None:
     r"""Find best points and sample nearby points.

@@ -1154,60 +1239,62 @@ def sample_points_around_best(
         An optional `n_discrete_points x d`-dim tensor containing the
            sampled points. This is None if no baseline points are found.
     """
-    X = get_X_baseline(acq_function=acq_function)
-    if X is None:
-        return
-    with torch.no_grad():
-        try:
-            posterior = acq_function.model.posterior(X)
-        except AttributeError:
-            warnings.warn(
-                "Failed to sample around previous best points.",
-                BotorchWarning,
-                stacklevel=3,
-            )
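+    # If `best_X` is supplied (e.g. sampled optima from an entropy-search
+    # acquisition function), perturb those points directly and skip the
+    # posterior-mean-based selection of best points below.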
+    if best_X is None:
+        X = get_X_baseline(acq_function=acq_function)
+        if X is None:
             return
-        mean = posterior.mean
-        while mean.ndim > 2:
-            # take average over batch dims
-            mean = mean.mean(dim=0)
-        try:
-            f_pred = acq_function.objective(mean)
-        # Some acquisition functions do not have an objective
-        # and for some acquisition functions the objective is None
-        except (AttributeError, TypeError):
-            f_pred = mean
-        if hasattr(acq_function, "maximize"):
-            # make sure that the optimiztaion direction is set properly
-            if not acq_function.maximize:
-                f_pred = -f_pred
-        try:
-            # handle constraints for EHVI-based acquisition functions
-            constraints = acq_function.constraints
-            if constraints is not None:
-                neg_violation = -torch.stack(
-                    [c(mean).clamp_min(0.0) for c in constraints], dim=-1
-                ).sum(dim=-1)
-                feas = neg_violation == 0
-                if feas.any():
-                    f_pred[~feas] = float("-inf")
-                else:
-                    # set objective equal to negative violation
-                    f_pred = neg_violation
-        except AttributeError:
-            pass
-        if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
-            # multi-objective
-            # find pareto set
-            is_pareto = is_non_dominated(f_pred)
-            best_X = X[is_pareto]
-        else:
-            if f_pred.shape[-1] == 1:
-                f_pred = f_pred.squeeze(-1)
-            n_best = max(1, round(X.shape[0] * best_pct / 100))
-            # the view() is to ensure that best_idcs is not a scalar tensor
-            best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
-            best_X = X[best_idcs]
+        with torch.no_grad():
+            try:
+                posterior = acq_function.model.posterior(X)
+            except AttributeError:
+                warnings.warn(
+                    "Failed to sample around previous best points.",
+                    BotorchWarning,
+                    stacklevel=3,
+                )
+                return
+            mean = posterior.mean
+            while mean.ndim > 2:
+                # take average over batch dims
+                mean = mean.mean(dim=0)
+            try:
+                f_pred = acq_function.objective(mean)
+            # Some acquisition functions do not have an objective
+            # and for some acquisition functions the objective is None
+            except (AttributeError, TypeError):
+                f_pred = mean
+            if hasattr(acq_function, "maximize"):
+                # make sure that the optimization direction is set properly
+                if not acq_function.maximize:
+                    f_pred = -f_pred
+            try:
+                # handle constraints for EHVI-based acquisition functions
+                constraints = acq_function.constraints
+                if constraints is not None:
+                    neg_violation = -torch.stack(
+                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
+                    ).sum(dim=-1)
+                    feas = neg_violation == 0
+                    if feas.any():
+                        f_pred[~feas] = float("-inf")
+                    else:
+                        # set objective equal to negative violation
+                        f_pred = neg_violation
+            except AttributeError:
+                pass
+            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
+                # multi-objective
+                # find pareto set
+                is_pareto = is_non_dominated(f_pred)
+                best_X = X[is_pareto]
+            else:
+                if f_pred.shape[-1] == 1:
+                    f_pred = f_pred.squeeze(-1)
+                n_best = max(1, round(X.shape[0] * best_pct / 100))
+                # the view() is to ensure that best_idcs is not a scalar tensor
+                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
+                best_X = X[best_idcs]
+
     use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
     n_trunc_normal_points = (
         n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points