88import itertools
99import random
1010import warnings
11- from typing import Any , Callable , Sequence
11+ from collections .abc import Mapping , Sequence
12+ from typing import Any , Callable
1213
1314import torch
1415from botorch .acquisition import AcquisitionFunction
@@ -215,8 +216,7 @@ def get_nearest_neighbors(
215216
216217def get_categorical_neighbors (
217218 current_x : Tensor ,
218- bounds : Tensor ,
219- cat_dims : Tensor ,
219+ cat_dims : dict [int , list [float ]],
220220 max_num_cat_values : int = MAX_DISCRETE_VALUES ,
221221) -> Tensor :
222222 r"""Generate all 1-Hamming distance neighbors of a given input. The neighbors
@@ -231,8 +231,8 @@ def get_categorical_neighbors(
231231
232232 Args:
233233 current_x: The design to find the neighbors of. A tensor of shape `d`.
234- bounds : A `2 x d` tensor of lower and upper bounds for each column of `X`.
235- cat_dims: A tensor of indices corresponding to categorical parameters .
234+ cat_dims : A dictionary mapping indices of categorical dimensions
235+ to a list of allowed values for that dimension .
236236 max_num_cat_values: Maximum number of values for a categorical parameter,
237237 beyond which values are uniformly sampled.
238238
@@ -246,31 +246,31 @@ def get_categorical_neighbors(
246246 def _get_cat_values (dim : int ) -> Sequence [int ]:
247247 r"""Get a sequence of up to `max_num_cat_values` values that a categorical
248248 feature may take."""
249- lb , ub = bounds [:, dim ].long ()
250249 current_value = current_x [dim ]
251- cat_values = range (lb , ub + 1 )
252- if ub - lb + 1 <= max_num_cat_values :
253- return cat_values
250+ if len (cat_dims [dim ]) <= max_num_cat_values :
251+ return cat_dims [dim ]
254252 else :
255253 return random .sample (
256- [v for v in cat_values if v != current_value ], k = max_num_cat_values
254+ [v for v in cat_dims [ dim ] if v != current_value ], k = max_num_cat_values
257255 )
258256
257+ new_cat_values_dict = {dim : _get_cat_values (dim ) for dim in cat_dims .keys ()}
259258 new_cat_values_lst = list (
260- itertools .chain .from_iterable (_get_cat_values ( dim ) for dim in cat_dims )
259+ itertools .chain .from_iterable (new_cat_values_dict . values () )
261260 )
262261 new_cat_values = torch .tensor (
263262 new_cat_values_lst , device = current_x .device , dtype = current_x .dtype
264263 )
265264
266- num_cat_values = (bounds [1 , :] - bounds [0 , :] + 1 ).to (dtype = torch .long )
267- num_cat_values .clamp_ (max = max_num_cat_values )
268265 new_cat_idcs = torch .cat (
269266 tuple (
270- torch .full ((num_cat_values [dim ].item (),), dim , device = current_x .device )
271- for dim in cat_dims
267+ torch .full (
268+ (min (len (values ), max_num_cat_values ),), dim , device = current_x .device
269+ )
270+ for dim , values in new_cat_values_dict .items ()
272271 )
273272 )
273+
274274 neighbors = current_x .repeat (len (new_cat_values ), 1 )
275275 # Assign the new values to their corresponding columns.
276276 neighbors .scatter_ (1 , new_cat_idcs .view (- 1 , 1 ), new_cat_values .view (- 1 , 1 ))
@@ -285,7 +285,7 @@ def get_spray_points(
285285 X_baseline : Tensor ,
286286 cont_dims : Tensor ,
287287 discrete_dims : dict [int , list [float ]],
288- cat_dims : Tensor ,
288+ cat_dims : dict [ int , list [ float ]] ,
289289 bounds : Tensor ,
290290 num_spray_points : int ,
291291 std_cont_perturbation : float = STD_CONT_PERTURBATION ,
@@ -301,7 +301,8 @@ def get_spray_points(
301301 cont_dims: Indices of continuous parameters/input dimensions.
302302 discrete_dims: A dictionary mapping indices of discrete dimensions
303303 to a list of allowed values for that dimension.
304- cat_dims: Indices of categorical parameters/input dimensions.
304+ cat_dims: A dictionary mapping indices of categorical dimensions
305+ to a list of allowed values for that dimension.
305306 bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
306307 num_spray_points: Number of spray points to return.
307308 std_cont_perturbation: standard deviation of Normal perturbations of
@@ -316,6 +317,7 @@ def get_spray_points(
316317 t_discrete_dims = torch .tensor (
317318 list (discrete_dims .keys ()), dtype = torch .long , device = device
318319 )
320+ t_cat_dims = torch .tensor (list (cat_dims .keys ()), dtype = torch .long , device = device )
319321 for x in X_baseline :
320322 if len (discrete_dims ) > 0 :
321323 discrete_perturbs = get_nearest_neighbors (
@@ -326,10 +328,8 @@ def get_spray_points(
326328 len (discrete_perturbs ), (num_spray_points ,), device = device
327329 )
328330 ]
329- if cat_dims .numel ():
330- cat_perturbs = get_categorical_neighbors (
331- current_x = x , bounds = bounds , cat_dims = cat_dims
332- )
331+ if len (cat_dims ) > 0 :
332+ cat_perturbs = get_categorical_neighbors (current_x = x , cat_dims = cat_dims )
333333 cat_perturbs = cat_perturbs [
334334 torch .randint (len (cat_perturbs ), (num_spray_points ,), device = device )
335335 ]
@@ -343,8 +343,8 @@ def get_spray_points(
343343 nbds = torch .zeros (num_spray_points , dim , device = device , dtype = dtype )
344344 if len (discrete_dims ) > 0 :
345345 nbds [..., t_discrete_dims ] = discrete_perturbs [..., t_discrete_dims ]
346- if cat_dims . numel () :
347- nbds [..., cat_dims ] = cat_perturbs [..., cat_dims ]
346+ if len ( cat_dims ) > 0 :
347+ nbds [..., t_cat_dims ] = cat_perturbs [..., t_cat_dims ]
348348
349349 nbds [..., cont_dims ] = cont_perturbs
350350 perturb_nbors = torch .cat ([perturb_nbors , nbds ], dim = 0 )
@@ -354,7 +354,7 @@ def get_spray_points(
354354def sample_feasible_points (
355355 opt_inputs : OptimizeAcqfInputs ,
356356 discrete_dims : dict [int , list [float ]],
357- cat_dims : Tensor ,
357+ cat_dims : dict [ int , list [ float ]] ,
358358 num_points : int ,
359359) -> Tensor :
360360 r"""Sample feasible points from the optimization domain.
@@ -374,7 +374,8 @@ def sample_feasible_points(
374374 opt_inputs: Common set of arguments for acquisition optimization.
375375 discrete_dims: A dictionary mapping indices of discrete dimensions
376376 to a list of allowed values for that dimension.
377- cat_dims: A tensor of indices corresponding to categorical parameters.
377+ cat_dims: A dictionary mapping indices of categorical dimensions
378+ to a list of allowed values for that dimension.
378379 num_points: The number of points to sample.
379380
380381 Returns:
@@ -413,7 +414,7 @@ def generator(n: int) -> Tensor:
413414 base_points = generator (n = num_remaining * 2 )
414415 # Round the discrete dimensions to the nearest integer.
415416 base_points = round_discrete_dims (X = base_points , discrete_dims = discrete_dims )
416- base_points [:, cat_dims ] = base_points [:, cat_dims ]. round ( )
417+ base_points = round_discrete_dims ( X = base_points , discrete_dims = cat_dims )
417418 # Fix the fixed features.
418419 base_points = fix_features (
419420 X = base_points ,
@@ -457,7 +458,7 @@ def round_discrete_dims(X: Tensor, discrete_dims: dict[int, list[float]]) -> Ten
457458def generate_starting_points (
458459 opt_inputs : OptimizeAcqfInputs ,
459460 discrete_dims : dict [int , list [float ]],
460- cat_dims : Tensor ,
461+ cat_dims : dict [ int , list [ float ]] ,
461462 cont_dims : Tensor ,
462463) -> tuple [Tensor , Tensor ]:
463464 """Generate initial starting points for the alternating optimization.
@@ -472,7 +473,8 @@ def generate_starting_points(
472473 from `opt_inputs`.
473474 discrete_dims: A dictionary mapping indices of discrete dimensions
474475 to a list of allowed values for that dimension.
475- cat_dims: A tensor of indices corresponding to categorical parameters.
476+ cat_dims: A dictionary mapping indices of categorical dimensions
477+ to a list of allowed values for that dimension.
476478 cont_dims: A tensor of indices corresponding to continuous parameters.
477479
478480 Returns:
@@ -625,7 +627,7 @@ def generate_starting_points(
625627def discrete_step (
626628 opt_inputs : OptimizeAcqfInputs ,
627629 discrete_dims : dict [int , list [float ]],
628- cat_dims : Tensor ,
630+ cat_dims : dict [ int , list [ float ]] ,
629631 current_x : Tensor ,
630632) -> tuple [Tensor , Tensor ]:
631633 """Discrete nearest neighbour search.
@@ -636,7 +638,8 @@ def discrete_step(
636638 and constraints from `opt_inputs`.
637639 discrete_dims: A dictionary mapping indices of discrete dimensions
638640 to a list of allowed values for that dimension.
639- cat_dims: A tensor of indices corresponding to categorical parameters.
641+ cat_dims: A dictionary mapping indices of categorical dimensions
642+ to a list of allowed values for that dimension.
640643 current_x: Batch of starting points. A tensor of shape `b x d`.
641644
642645 Returns:
@@ -676,10 +679,9 @@ def discrete_step(
676679 neighbors .append (x_neighbors_discrete )
677680
678681 # if we have cat_dims look for neighbors by changing the cat's
679- if cat_dims . numel () :
682+ if len ( cat_dims ) > 0 :
680683 x_neighbors_cat = get_categorical_neighbors (
681684 current_x = current_x [i ].detach (),
682- bounds = opt_inputs .bounds ,
683685 cat_dims = cat_dims ,
684686 )
685687 x_neighbors_cat = _filter_infeasible (
@@ -806,8 +808,8 @@ def continuous_step(
806808def optimize_acqf_mixed_alternating (
807809 acq_function : AcquisitionFunction ,
808810 bounds : Tensor ,
809- discrete_dims : dict [int , list [float ]] | None = None ,
810- cat_dims : list [int ] | None = None ,
811+ discrete_dims : Mapping [int , Sequence [float ]] | None = None ,
812+ cat_dims : Mapping [int , Sequence [ float ] ] | None = None ,
811813 options : dict [str , Any ] | None = None ,
812814 q : int = 1 ,
813815 raw_samples : int = RAW_SAMPLES ,
@@ -837,7 +839,8 @@ def optimize_acqf_mixed_alternating(
837839 bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
838840 discrete_dims: A dictionary mapping indices of discrete and binary
839841 dimensions to a list of allowed values for that dimension.
840- cat_dims: A list of indices corresponding to categorical parameters.
842+ cat_dims: A dictionary mapping indices of categorical dimensions
843+ to a list of allowed values for that dimension.
841844 options: Dictionary specifying optimization options. Supports the following:
842845 - "initialization_strategy": Strategy used to generate the initial candidates.
843846 "random", "continuous_relaxation" or "equally_spaced" (linspace style).
@@ -891,12 +894,15 @@ def optimize_acqf_mixed_alternating(
891894 "sequential optimization."
892895 )
893896
894- cat_dims = cat_dims or []
897+ cat_dims = cat_dims or {}
895898 discrete_dims = discrete_dims or {}
896899
897900 # sort the values in discrete dims in ascending order
898901 discrete_dims = {dim : sorted (values ) for dim , values in discrete_dims .items ()}
899902
903+ # sort the values in categorical dims in ascending order
904+ cat_dims = {dim : sorted (values ) for dim , values in cat_dims .items ()}
905+
900906 for dim , values in discrete_dims .items ():
901907 lower_bnd , upper_bnd = bounds [:, dim ].tolist ()
902908 lower , upper = values [0 ], values [- 1 ]
@@ -972,8 +978,10 @@ def optimize_acqf_mixed_alternating(
972978 for dim , values in discrete_dims .items ()
973979 if dim not in fixed_features
974980 }
975- cat_dims = [dim for dim in cat_dims if dim not in fixed_features ]
976- non_cont_dims = [* discrete_dims .keys (), * cat_dims ]
981+ cat_dims = {
982+ dim : values for dim , values in cat_dims .items () if dim not in fixed_features
983+ }
984+ non_cont_dims = [* discrete_dims .keys (), * cat_dims .keys ()]
977985 if len (non_cont_dims ) == 0 :
978986 # If the problem is fully continuous, fall back to standard optimization.
979987 return _optimize_acqf (
@@ -989,13 +997,15 @@ def optimize_acqf_mixed_alternating(
989997 and max (non_cont_dims ) <= dim - 1
990998 ):
991999 raise ValueError (
992- "`discrete_dims` and `cat_dims` must be lists with unique, disjoint "
993- "integers between 0 and num_dims - 1."
1000+ "`discrete_dims` and `cat_dims` must be dictionaries with unique, disjoint "
1001+ "integers as keys between 0 and num_dims - 1."
9941002 )
9951003 discrete_dims_t = torch .tensor (
9961004 list (discrete_dims .keys ()), dtype = torch .long , device = tkwargs ["device" ]
9971005 )
998- cat_dims_t = torch .tensor (cat_dims , dtype = torch .long , device = tkwargs ["device" ])
1006+ cat_dims_t = torch .tensor (
1007+ list (cat_dims .keys ()), dtype = torch .long , device = tkwargs ["device" ]
1008+ )
9991009 non_cont_dims = torch .tensor (
10001010 non_cont_dims , dtype = torch .long , device = tkwargs ["device" ]
10011011 )
@@ -1011,7 +1021,7 @@ def optimize_acqf_mixed_alternating(
10111021 best_X , best_acq_val = generate_starting_points (
10121022 opt_inputs = opt_inputs ,
10131023 discrete_dims = discrete_dims ,
1014- cat_dims = cat_dims_t ,
1024+ cat_dims = cat_dims ,
10151025 cont_dims = cont_dims ,
10161026 )
10171027
@@ -1021,7 +1031,7 @@ def optimize_acqf_mixed_alternating(
10211031 best_X [~ done ], best_acq_val [~ done ] = discrete_step (
10221032 opt_inputs = opt_inputs ,
10231033 discrete_dims = discrete_dims ,
1024- cat_dims = cat_dims_t ,
1034+ cat_dims = cat_dims ,
10251035 current_x = best_X [~ done ],
10261036 )
10271037
0 commit comments