
Commit c2ed4d6

Author: Sahran Ashoor (committed)

Updated optim + model files with respect to upstream

1 parent 04ae7c4 · commit c2ed4d6

File tree

4 files changed (+190 additions, -59 deletions)


botorch/models/fully_bayesian_multitask.py

Lines changed: 39 additions & 26 deletions

@@ -19,12 +19,14 @@
     reshape_and_detach,
     SaasPyroModel,
 )
+from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
 from gpytorch.distributions import MultivariateNormal
-from gpytorch.kernels import IndexKernel, MaternKernel
+from gpytorch.kernels import MaternKernel
+from gpytorch.kernels.index_kernel import IndexKernel
 from gpytorch.kernels.kernel import Kernel
 from gpytorch.likelihoods.likelihood import Likelihood
 from gpytorch.means.mean import Mean
@@ -137,7 +139,7 @@ def sample_task_lengthscale(

     def load_mcmc_samples(
         self, mcmc_samples: dict[str, Tensor]
-    ) -> tuple[Mean, Kernel, Likelihood, Kernel, Parameter]:
+    ) -> tuple[Mean, Kernel, Likelihood, Kernel]:
         r"""Load the MCMC samples into the mean_module, covar_module, and likelihood."""
         tkwargs = {"device": self.train_X.device, "dtype": self.train_X.dtype}
         num_mcmc_samples = len(mcmc_samples["mean"])
@@ -406,30 +408,7 @@ def posterior(

     def forward(self, X: Tensor) -> MultivariateNormal:
         self._check_if_fitted()
-        x_basic, task_idcs = self._split_inputs(X)
-
-        mean_x = self.mean_module(x_basic)
-        covar_x = self.covar_module(x_basic)
-
-        tsub_idcs = task_idcs.squeeze(-1)
-        if tsub_idcs.ndim > 1:
-            tsub_idcs = tsub_idcs.squeeze(-2)
-        latent_features = self.latent_features[:, tsub_idcs, :]
-
-        if X.ndim > 3:
-            # batch eval mode
-            # for X (batch_shape x num_samples x q x d), task_idcs[:,i,:,] are the same
-            # reshape X to (batch_shape x num_samples x q x d)
-            latent_features = latent_features.permute(
-                [-i for i in range(X.ndim - 1, 2, -1)]
-                + [0]
-                + [-i for i in range(2, 0, -1)]
-            )
-
-        # Combine the two in an ICM fashion
-        covar_i = self.task_covar_module(latent_features)
-        covar = covar_x.mul(covar_i)
-        return MultivariateNormal(mean_x, covar)
+        return super().forward(X)

     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         r"""Custom logic for loading the state dict.
@@ -474,3 +453,37 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
         # Load the actual samples from the state dict
         super().load_state_dict(state_dict=state_dict, strict=strict)
+
+    def condition_on_observations(
+        self, X: Tensor, Y: Tensor, **kwargs: Any
+    ) -> BatchedMultiOutputGPyTorchModel:
+        """Conditions on additional observations for a Fully Bayesian model (either
+        identical across models or unique per-model).
+
+        Args:
+            X: A `batch_shape x num_samples x d`-dim Tensor, where `d` is
+                the dimension of the feature space and `batch_shape` is the number of
+                sampled models.
+            Y: A `batch_shape x num_samples x 1`-dim Tensor of training observations,
+                where `batch_shape` is the number of sampled models.
+
+        Returns:
+            BatchedMultiOutputGPyTorchModel: A fully bayesian model conditioned on
+                given observations. The returned model has `batch_shape` copies of the
+                training data in case of identical observations (and `batch_shape`
+                training datasets otherwise).
+        """
+        if X.ndim == 2 and Y.ndim == 2:
+            # To avoid an error in GPyTorch when inferring the batch dimension, we add
+            # the explicit batch shape here. The result is that the conditioned model
+            # will have 'batch_shape' copies of the training data.
+            X = X.repeat(self.batch_shape + (1, 1))
+            Y = Y.repeat(self.batch_shape + (1, 1))
+
+        elif X.ndim < Y.ndim:
+            # We need to duplicate the training data to enable correct batch
+            # size inference in gpytorch.
+            X = X.repeat(*(Y.shape[:-2] + (1, 1)))
+
+        return super().condition_on_observations(X, Y, **kwargs)
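
For context, the following is a minimal usage sketch of the new `condition_on_observations` override; it is not part of the commit. It assumes the existing `SaasFullyBayesianMultiTaskGP` constructor and the `fit_fully_bayesian_model_nuts` fitting utility, with arbitrary toy MCMC settings; the posterior call before conditioning mirrors the test added below.

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP

# Toy data: 10 points with 4 features plus a trailing task-index column.
train_X = torch.rand(10, 4, dtype=torch.double)
task_idx = torch.cat([torch.zeros(5, 1), torch.ones(5, 1)]).to(train_X)
train_X = torch.cat([train_X, task_idx], dim=-1)
train_Y = torch.sin(train_X[:, :1])

model = SaasFullyBayesianMultiTaskGP(train_X, train_Y, task_feature=4)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)

# A posterior (forward) pass is needed before conditioning, as in the test below.
model.posterior(train_X)

# Un-batched conditioning data (2 points on task 0): per the docstring above,
# these are replicated across the model's MCMC batch dimension.
cond_X = torch.cat(
    [torch.rand(2, 4, dtype=torch.double), torch.zeros(2, 1, dtype=torch.double)],
    dim=-1,
)
cond_Y = torch.sin(cond_X[:, :1])
cond_model = model.condition_on_observations(cond_X, cond_Y)

# Training inputs now carry the batch of sampled models:
# (num_mcmc_samples, 10 + 2, 5).
print(cond_model.train_inputs[0].shape)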

botorch/optim/optimize_mixed.py

Lines changed: 16 additions & 15 deletions

@@ -8,8 +8,7 @@
 import itertools
 import random
 import warnings
-from collections.abc import Sequence
-from typing import Any, Callable
+from typing import Any, Callable, Sequence

 import torch
 from botorch.acquisition import AcquisitionFunction
@@ -574,6 +573,7 @@ def generate_starting_points(
         X_baseline=X_baseline,
         cont_dims=cont_dims,
         discrete_dims=discrete_dims,
+        cat_dims=cat_dims,
         bounds=bounds,
         num_spray_points=num_spray_points,
         std_cont_perturbation=assert_is_instance(
@@ -598,6 +598,7 @@ def generate_starting_points(
         new_x_init = sample_feasible_points(
             opt_inputs=opt_inputs,
             discrete_dims=discrete_dims,
+            cat_dims=cat_dims,
             num_points=num_restarts - len(x_init_candts),
         )
         x_init_candts = torch.cat([x_init_candts, new_x_init], dim=0)
@@ -817,19 +818,19 @@ def optimize_acqf_mixed_alternating(
     inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""
-    Optimizes acquisition function over mixed binary and continuous input spaces.
-    Multiple random restarting starting points are picked by evaluating a large set
-    of initial candidates. From each starting point, alternating discrete local search
-    and continuous optimization via (L-BFGS) is performed for a fixed number of
-    iterations.
+    Optimizes acquisition function over mixed integer, categorical, and continuous
+    input spaces. Multiple random restarting starting points are picked by evaluating
+    a large set of initial candidates. From each starting point, alternating
+    discrete/categorical local search and continuous optimization via (L-BFGS)
+    is performed for a fixed number of iterations.

     NOTE: This method assumes that all categorical variables are
     integer valued.
     The discrete dimensions that have more than
     `options.get("max_discrete_values", MAX_DISCRETE_VALUES)` values will
     be optimized using continuous relaxation.
-
-    # TODO: Support categorical variables.
+    The categorical dimensions that have more than `MAX_DISCRETE_VALUES` values
+    will be optimized by selecting random subsamples of the possible values.

     Args:
         acq_function: BoTorch Acquisition function.
@@ -982,14 +983,14 @@ def optimize_acqf_mixed_alternating(
             )
         )
     if not (
-        isinstance(discrete_dims, list)
-        and len(set(discrete_dims)) == len(discrete_dims)
-        and min(discrete_dims) >= 0
-        and max(discrete_dims) <= dim - 1
+        isinstance(non_cont_dims, list)
+        and len(set(non_cont_dims)) == len(non_cont_dims)
+        and min(non_cont_dims) >= 0
+        and max(non_cont_dims) <= dim - 1
     ):
         raise ValueError(
-            "`discrete_dims` must be a list with unique integers "
-            "between 0 and num_dims - 1."
+            "`discrete_dims` and `cat_dims` must be lists with unique, disjoint "
+            "integers between 0 and num_dims - 1."
         )
     discrete_dims_t = torch.tensor(
         list(discrete_dims.keys()), dtype=torch.long, device=tkwargs["device"]
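
As a usage illustration (not part of the commit), here is a minimal sketch of calling `optimize_acqf_mixed_alternating` with both `discrete_dims` and the `cat_dims` argument that this commit threads through `generate_starting_points` and `sample_feasible_points`. It assumes the list-of-dimension-indices form of those arguments and a toy, unnormalized `SingleTaskGP` surrogate; adapt to the installed BoTorch version as needed.

import torch
from botorch.acquisition import ExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy 4-d problem: dims 0-1 continuous, dim 2 integer-valued (0..3),
# dim 3 categorical and integer-encoded (0..2), per the NOTE in the docstring.
train_X = torch.rand(20, 4, dtype=torch.double)
train_X[:, 2] = (train_X[:, 2] * 3).round()
train_X[:, 3] = (train_X[:, 3] * 2).round()
train_Y = torch.sin(train_X.sum(dim=-1, keepdim=True))

model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
acqf = ExpectedImprovement(model=model, best_f=train_Y.max())

bounds = torch.tensor(
    [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 3.0, 2.0]], dtype=torch.double
)
candidate, value = optimize_acqf_mixed_alternating(
    acq_function=acqf,
    bounds=bounds,
    discrete_dims=[2],  # integer-valued dimension, optimized by local search
    cat_dims=[3],       # categorical dimension (integer-encoded levels)
    q=1,
    num_restarts=4,
    raw_samples=32,
)
print(candidate, value)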

test/models/test_fully_bayesian_multitask.py

Lines changed: 125 additions & 15 deletions

@@ -56,7 +56,6 @@
 from gpytorch.means import ConstantMean

 EXPECTED_KEYS = [
-    "latent_features",
     "mean_module.raw_constant",
     "covar_module.kernels.1.raw_var",
     "covar_module.kernels.1.active_dims",
@@ -112,7 +111,7 @@ def _get_data_and_model(
         )
         return train_X, train_Y, train_Yvar, model

-    def _get_unnormalized_data(self, **tkwargs):
+    def _get_unnormalized_data(self, infer_noise: bool = False, **tkwargs):
         with torch.random.fork_rng():
             torch.manual_seed(0)
             train_X = torch.rand(10, 4, **tkwargs)
@@ -122,9 +121,28 @@ def _get_unnormalized_data(self, **tkwargs):
             )
             train_X = torch.cat([5 + 5 * train_X, task_indices], dim=1)
             test_X = 5 + 5 * torch.rand(5, 4, **tkwargs)
-            train_Yvar = 0.1 * torch.arange(10, **tkwargs).unsqueeze(-1)
+            if infer_noise:
+                train_Yvar = None
+            else:
+                train_Yvar = 0.1 * torch.arange(10, **tkwargs).unsqueeze(-1)
         return train_X, train_Y, train_Yvar, test_X

+    def _get_unnormalized_condition_data(
+        self, num_models: int, num_cond: int, dim: int, infer_noise: bool, **tkwargs
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
+        with torch.random.fork_rng():
+            torch.manual_seed(0)
+            cond_X = 5 + 5 * torch.rand(num_models, num_cond, dim, **tkwargs)
+            cond_Y = 10 + torch.sin(cond_X[..., :1])
+            cond_Yvar = (
+                None if infer_noise else 0.1 * torch.ones(cond_Y.shape, **tkwargs)
+            )
+            # adding the task dimension
+            cond_X = torch.cat(
+                [cond_X, torch.zeros(num_models, num_cond, 1, **tkwargs)], dim=-1
+            )
+        return cond_X, cond_Y, cond_Yvar
+
     def _get_mcmc_samples(self, num_samples: int, dim: int, task_rank: int, **tkwargs):
         mcmc_samples = {
             "lengthscale": torch.rand(num_samples, 1, dim, **tkwargs),
@@ -604,6 +622,110 @@ def test_acquisition_functions(self):
             )
             self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))

+    def test_condition_on_observation(self) -> None:
+        # The following conditioned data shapes should work (output describes):
+        # training data shape after cond(batch shape in output is req. in gpytorch)
+        # X: num_models x n x d, Y: num_models x n x d --> num_models x n x d
+        # X: n x d, Y: n x d --> num_models x n x d
+        # X: n x d, Y: num_models x n x d --> num_models x n x d
+        num_models = 3
+        num_cond = 2
+        task_rank = 2
+        for infer_noise, dtype in itertools.product(
+            (True, False), (torch.float, torch.double)
+        ):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            train_X, _, _, model = self._get_data_and_model(
+                task_rank=task_rank,
+                infer_noise=infer_noise,
+                **tkwargs,
+            )
+            num_dims = train_X.shape[1] - 1
+            mcmc_samples = self._get_mcmc_samples(
+                num_samples=3,
+                dim=num_dims,
+                task_rank=task_rank,
+                **tkwargs,
+            )
+            model.load_mcmc_samples(mcmc_samples)
+
+            num_train = train_X.shape[0]
+            test_X = torch.rand(num_models, num_dims, **tkwargs)
+
+            cond_X, cond_Y, cond_Yvar = self._get_unnormalized_condition_data(
+                num_models=num_models,
+                num_cond=num_cond,
+                infer_noise=infer_noise,
+                dim=num_dims,
+                **tkwargs,
+            )
+
+            # need to forward pass before conditioning
+            model.posterior(train_X)
+            cond_model = model.condition_on_observations(
+                cond_X, cond_Y, noise=cond_Yvar
+            )
+            posterior = cond_model.posterior(test_X)
+            self.assertEqual(
+                posterior.mean.shape, torch.Size([num_models, len(test_X), 2])
+            )
+
+            # since the data is not equal for the conditioned points, a batch size
+            # is added to the training data
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims + 1]),
+            )
+
+            # the batch shape of the condition model is added during conditioning
+            self.assertEqual(cond_model.batch_shape, torch.Size([num_models]))
+
+            # condition on identical sets of data (i.e. one set) for all models
+            # i.e, with no batch shape. This infers the batch shape.
+            cond_X_nobatch, cond_Y_nobatch = cond_X[0], cond_Y[0]
+
+            # conditioning without a batch size - the resulting conditioned model
+            # will still have a batch size
+            model.posterior(train_X)
+            cond_model = model.condition_on_observations(
+                cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims + 1]),
+            )
+
+            # With batch size only on Y.
+            cond_model = model.condition_on_observations(
+                cond_X_nobatch, cond_Y, noise=cond_Yvar
+            )
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims + 1]),
+            )
+
+            # test repeated conditioning
+            repeat_cond_X = cond_X.clone()
+            repeat_cond_X[..., 0:-1] += 2
+            repeat_cond_model = cond_model.condition_on_observations(
+                repeat_cond_X, cond_Y, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 2 * num_cond, num_dims + 1]),
+            )
+
+            # test repeated conditioning without a batch size
+            repeat_cond_X_nobatch = cond_X_nobatch.clone()
+            repeat_cond_X_nobatch[..., 0:-1] += 2
+            repeat_cond_model2 = repeat_cond_model.condition_on_observations(
+                repeat_cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model2.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 3 * num_cond, num_dims + 1]),
+            )
+
     def test_load_samples(self):
         for task_rank, dtype, use_outcome_transform in itertools.product(
             [1, 2], [torch.float, torch.double], (False, True)
@@ -671,18 +793,6 @@ def test_load_samples(self):
                     train_Yvar_tf.clamp(MIN_INFERRED_NOISE_LEVEL),
                 )
             )
-            self.assertTrue(
-                torch.allclose(
-                    model.task_covar_module.lengthscale,
-                    mcmc_samples["task_lengthscale"],
-                )
-            )
-            self.assertTrue(
-                torch.allclose(
-                    model.latent_features,
-                    mcmc_samples["latent_features"],
-                )
-            )

     def test_construct_inputs(self):
         for dtype, infer_noise in [(torch.float, False), (torch.double, True)]:
