Skip to content

Commit ebe03af

Browse files
authored
Merge branch 'main' into main
2 parents 7fe9237 + 0290e3d commit ebe03af

File tree

2 files changed

+106
-67
lines changed

2 files changed

+106
-67
lines changed

botorch/optim/optimize_mixed.py

Lines changed: 53 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,8 @@
88
import itertools
99
import random
1010
import warnings
11-
from typing import Any, Callable, Sequence
11+
from collections.abc import Mapping, Sequence
12+
from typing import Any, Callable
1213

1314
import torch
1415
from botorch.acquisition import AcquisitionFunction
@@ -215,8 +216,7 @@ def get_nearest_neighbors(
215216

216217
def get_categorical_neighbors(
217218
current_x: Tensor,
218-
bounds: Tensor,
219-
cat_dims: Tensor,
219+
cat_dims: dict[int, list[float]],
220220
max_num_cat_values: int = MAX_DISCRETE_VALUES,
221221
) -> Tensor:
222222
r"""Generate all 1-Hamming distance neighbors of a given input. The neighbors
@@ -231,8 +231,8 @@ def get_categorical_neighbors(
231231
232232
Args:
233233
current_x: The design to find the neighbors of. A tensor of shape `d`.
234-
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
235-
cat_dims: A tensor of indices corresponding to categorical parameters.
234+
cat_dims: A dictionary mapping indices of categorical dimensions
235+
to a list of allowed values for that dimension.
236236
max_num_cat_values: Maximum number of values for a categorical parameter,
237237
beyond which values are uniformly sampled.
238238
@@ -246,31 +246,31 @@ def get_categorical_neighbors(
246246
def _get_cat_values(dim: int) -> Sequence[int]:
247247
r"""Get a sequence of up to `max_num_cat_values` values that a categorical
248248
feature may take."""
249-
lb, ub = bounds[:, dim].long()
250249
current_value = current_x[dim]
251-
cat_values = range(lb, ub + 1)
252-
if ub - lb + 1 <= max_num_cat_values:
253-
return cat_values
250+
if len(cat_dims[dim]) <= max_num_cat_values:
251+
return cat_dims[dim]
254252
else:
255253
return random.sample(
256-
[v for v in cat_values if v != current_value], k=max_num_cat_values
254+
[v for v in cat_dims[dim] if v != current_value], k=max_num_cat_values
257255
)
258256

257+
new_cat_values_dict = {dim: _get_cat_values(dim) for dim in cat_dims.keys()}
259258
new_cat_values_lst = list(
260-
itertools.chain.from_iterable(_get_cat_values(dim) for dim in cat_dims)
259+
itertools.chain.from_iterable(new_cat_values_dict.values())
261260
)
262261
new_cat_values = torch.tensor(
263262
new_cat_values_lst, device=current_x.device, dtype=current_x.dtype
264263
)
265264

266-
num_cat_values = (bounds[1, :] - bounds[0, :] + 1).to(dtype=torch.long)
267-
num_cat_values.clamp_(max=max_num_cat_values)
268265
new_cat_idcs = torch.cat(
269266
tuple(
270-
torch.full((num_cat_values[dim].item(),), dim, device=current_x.device)
271-
for dim in cat_dims
267+
torch.full(
268+
(min(len(values), max_num_cat_values),), dim, device=current_x.device
269+
)
270+
for dim, values in new_cat_values_dict.items()
272271
)
273272
)
273+
274274
neighbors = current_x.repeat(len(new_cat_values), 1)
275275
# Assign the new values to their corresponding columns.
276276
neighbors.scatter_(1, new_cat_idcs.view(-1, 1), new_cat_values.view(-1, 1))
@@ -285,7 +285,7 @@ def get_spray_points(
285285
X_baseline: Tensor,
286286
cont_dims: Tensor,
287287
discrete_dims: dict[int, list[float]],
288-
cat_dims: Tensor,
288+
cat_dims: dict[int, list[float]],
289289
bounds: Tensor,
290290
num_spray_points: int,
291291
std_cont_perturbation: float = STD_CONT_PERTURBATION,
@@ -301,7 +301,8 @@ def get_spray_points(
301301
cont_dims: Indices of continuous parameters/input dimensions.
302302
discrete_dims: A dictionary mapping indices of discrete dimensions
303303
to a list of allowed values for that dimension.
304-
cat_dims: Indices of categorical parameters/input dimensions.
304+
cat_dims: A dictionary mapping indices of categorical dimensions
305+
to a list of allowed values for that dimension.
305306
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
306307
num_spray_points: Number of spray points to return.
307308
std_cont_perturbation: standard deviation of Normal perturbations of
@@ -316,6 +317,7 @@ def get_spray_points(
316317
t_discrete_dims = torch.tensor(
317318
list(discrete_dims.keys()), dtype=torch.long, device=device
318319
)
320+
t_cat_dims = torch.tensor(list(cat_dims.keys()), dtype=torch.long, device=device)
319321
for x in X_baseline:
320322
if len(discrete_dims) > 0:
321323
discrete_perturbs = get_nearest_neighbors(
@@ -326,10 +328,8 @@ def get_spray_points(
326328
len(discrete_perturbs), (num_spray_points,), device=device
327329
)
328330
]
329-
if cat_dims.numel():
330-
cat_perturbs = get_categorical_neighbors(
331-
current_x=x, bounds=bounds, cat_dims=cat_dims
332-
)
331+
if len(cat_dims) > 0:
332+
cat_perturbs = get_categorical_neighbors(current_x=x, cat_dims=cat_dims)
333333
cat_perturbs = cat_perturbs[
334334
torch.randint(len(cat_perturbs), (num_spray_points,), device=device)
335335
]
@@ -343,8 +343,8 @@ def get_spray_points(
343343
nbds = torch.zeros(num_spray_points, dim, device=device, dtype=dtype)
344344
if len(discrete_dims) > 0:
345345
nbds[..., t_discrete_dims] = discrete_perturbs[..., t_discrete_dims]
346-
if cat_dims.numel():
347-
nbds[..., cat_dims] = cat_perturbs[..., cat_dims]
346+
if len(cat_dims) > 0:
347+
nbds[..., t_cat_dims] = cat_perturbs[..., t_cat_dims]
348348

349349
nbds[..., cont_dims] = cont_perturbs
350350
perturb_nbors = torch.cat([perturb_nbors, nbds], dim=0)
@@ -354,7 +354,7 @@ def get_spray_points(
354354
def sample_feasible_points(
355355
opt_inputs: OptimizeAcqfInputs,
356356
discrete_dims: dict[int, list[float]],
357-
cat_dims: Tensor,
357+
cat_dims: dict[int, list[float]],
358358
num_points: int,
359359
) -> Tensor:
360360
r"""Sample feasible points from the optimization domain.
@@ -374,7 +374,8 @@ def sample_feasible_points(
374374
opt_inputs: Common set of arguments for acquisition optimization.
375375
discrete_dims: A dictionary mapping indices of discrete dimensions
376376
to a list of allowed values for that dimension.
377-
cat_dims: A tensor of indices corresponding to categorical parameters.
377+
cat_dims: A dictionary mapping indices of categorical dimensions
378+
to a list of allowed values for that dimension.
378379
num_points: The number of points to sample.
379380
380381
Returns:
@@ -413,7 +414,7 @@ def generator(n: int) -> Tensor:
413414
base_points = generator(n=num_remaining * 2)
414415
# Round the discrete dimensions to the nearest integer.
415416
base_points = round_discrete_dims(X=base_points, discrete_dims=discrete_dims)
416-
base_points[:, cat_dims] = base_points[:, cat_dims].round()
417+
base_points = round_discrete_dims(X=base_points, discrete_dims=cat_dims)
417418
# Fix the fixed features.
418419
base_points = fix_features(
419420
X=base_points,
@@ -457,7 +458,7 @@ def round_discrete_dims(X: Tensor, discrete_dims: dict[int, list[float]]) -> Ten
457458
def generate_starting_points(
458459
opt_inputs: OptimizeAcqfInputs,
459460
discrete_dims: dict[int, list[float]],
460-
cat_dims: Tensor,
461+
cat_dims: dict[int, list[float]],
461462
cont_dims: Tensor,
462463
) -> tuple[Tensor, Tensor]:
463464
"""Generate initial starting points for the alternating optimization.
@@ -472,7 +473,8 @@ def generate_starting_points(
472473
from `opt_inputs`.
473474
discrete_dims: A dictionary mapping indices of discrete dimensions
474475
to a list of allowed values for that dimension.
475-
cat_dims: A tensor of indices corresponding to categorical parameters.
476+
cat_dims: A dictionary mapping indices of categorical dimensions
477+
to a list of allowed values for that dimension.
476478
cont_dims: A tensor of indices corresponding to continuous parameters.
477479
478480
Returns:
@@ -625,7 +627,7 @@ def generate_starting_points(
625627
def discrete_step(
626628
opt_inputs: OptimizeAcqfInputs,
627629
discrete_dims: dict[int, list[float]],
628-
cat_dims: Tensor,
630+
cat_dims: dict[int, list[float]],
629631
current_x: Tensor,
630632
) -> tuple[Tensor, Tensor]:
631633
"""Discrete nearest neighbour search.
@@ -636,7 +638,8 @@ def discrete_step(
636638
and constraints from `opt_inputs`.
637639
discrete_dims: A dictionary mapping indices of discrete dimensions
638640
to a list of allowed values for that dimension.
639-
cat_dims: A tensor of indices corresponding to categorical parameters.
641+
cat_dims: A dictionary mapping indices of categorical dimensions
642+
to a list of allowed values for that dimension.
640643
current_x: Batch of starting points. A tensor of shape `b x d`.
641644
642645
Returns:
@@ -676,10 +679,9 @@ def discrete_step(
676679
neighbors.append(x_neighbors_discrete)
677680

678681
# if we have cat_dims look for neighbors by changing the cat's
679-
if cat_dims.numel():
682+
if len(cat_dims) > 0:
680683
x_neighbors_cat = get_categorical_neighbors(
681684
current_x=current_x[i].detach(),
682-
bounds=opt_inputs.bounds,
683685
cat_dims=cat_dims,
684686
)
685687
x_neighbors_cat = _filter_infeasible(
@@ -806,8 +808,8 @@ def continuous_step(
806808
def optimize_acqf_mixed_alternating(
807809
acq_function: AcquisitionFunction,
808810
bounds: Tensor,
809-
discrete_dims: dict[int, list[float]] | None = None,
810-
cat_dims: list[int] | None = None,
811+
discrete_dims: Mapping[int, Sequence[float]] | None = None,
812+
cat_dims: Mapping[int, Sequence[float]] | None = None,
811813
options: dict[str, Any] | None = None,
812814
q: int = 1,
813815
raw_samples: int = RAW_SAMPLES,
@@ -837,7 +839,8 @@ def optimize_acqf_mixed_alternating(
837839
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
838840
discrete_dims: A dictionary mapping indices of discrete and binary
839841
dimensions to a list of allowed values for that dimension.
840-
cat_dims: A list of indices corresponding to categorical parameters.
842+
cat_dims: A dictionary mapping indices of categorical dimensions
843+
to a list of allowed values for that dimension.
841844
options: Dictionary specifying optimization options. Supports the following:
842845
- "initialization_strategy": Strategy used to generate the initial candidates.
843846
"random", "continuous_relaxation" or "equally_spaced" (linspace style).
@@ -891,12 +894,15 @@ def optimize_acqf_mixed_alternating(
891894
"sequential optimization."
892895
)
893896

894-
cat_dims = cat_dims or []
897+
cat_dims = cat_dims or {}
895898
discrete_dims = discrete_dims or {}
896899

897900
# sort the values in discrete dims in ascending order
898901
discrete_dims = {dim: sorted(values) for dim, values in discrete_dims.items()}
899902

903+
# sort the categorical dims in ascending order
904+
cat_dims = {dim: sorted(values) for dim, values in cat_dims.items()}
905+
900906
for dim, values in discrete_dims.items():
901907
lower_bnd, upper_bnd = bounds[:, dim].tolist()
902908
lower, upper = values[0], values[-1]
@@ -972,8 +978,10 @@ def optimize_acqf_mixed_alternating(
972978
for dim, values in discrete_dims.items()
973979
if dim not in fixed_features
974980
}
975-
cat_dims = [dim for dim in cat_dims if dim not in fixed_features]
976-
non_cont_dims = [*discrete_dims.keys(), *cat_dims]
981+
cat_dims = {
982+
dim: values for dim, values in cat_dims.items() if dim not in fixed_features
983+
}
984+
non_cont_dims = [*discrete_dims.keys(), *cat_dims.keys()]
977985
if len(non_cont_dims) == 0:
978986
# If the problem is fully continuous, fall back to standard optimization.
979987
return _optimize_acqf(
@@ -989,13 +997,15 @@ def optimize_acqf_mixed_alternating(
989997
and max(non_cont_dims) <= dim - 1
990998
):
991999
raise ValueError(
992-
"`discrete_dims` and `cat_dims` must be lists with unique, disjoint "
993-
"integers between 0 and num_dims - 1."
1000+
"`discrete_dims` and `cat_dims` must be dictionaries with unique, disjoint "
1001+
"integers as keys between 0 and num_dims - 1."
9941002
)
9951003
discrete_dims_t = torch.tensor(
9961004
list(discrete_dims.keys()), dtype=torch.long, device=tkwargs["device"]
9971005
)
998-
cat_dims_t = torch.tensor(cat_dims, dtype=torch.long, device=tkwargs["device"])
1006+
cat_dims_t = torch.tensor(
1007+
list(cat_dims.keys()), dtype=torch.long, device=tkwargs["device"]
1008+
)
9991009
non_cont_dims = torch.tensor(
10001010
non_cont_dims, dtype=torch.long, device=tkwargs["device"]
10011011
)
@@ -1011,7 +1021,7 @@ def optimize_acqf_mixed_alternating(
10111021
best_X, best_acq_val = generate_starting_points(
10121022
opt_inputs=opt_inputs,
10131023
discrete_dims=discrete_dims,
1014-
cat_dims=cat_dims_t,
1024+
cat_dims=cat_dims,
10151025
cont_dims=cont_dims,
10161026
)
10171027

@@ -1021,7 +1031,7 @@ def optimize_acqf_mixed_alternating(
10211031
best_X[~done], best_acq_val[~done] = discrete_step(
10221032
opt_inputs=opt_inputs,
10231033
discrete_dims=discrete_dims,
1024-
cat_dims=cat_dims_t,
1034+
cat_dims=cat_dims,
10251035
current_x=best_X[~done],
10261036
)
10271037

0 commit comments

Comments
 (0)