diff --git a/botorch/acquisition/__init__.py b/botorch/acquisition/__init__.py
index 5bd208cd81..276cbd3aa0 100644
--- a/botorch/acquisition/__init__.py
+++ b/botorch/acquisition/__init__.py
@@ -37,6 +37,7 @@
 from botorch.acquisition.logei import (
     LogImprovementMCAcquisitionFunction,
     qLogExpectedImprovement,
+    qLogNoisyExpectedImprovement,
 )
 from botorch.acquisition.max_value_entropy_search import (
     MaxValueBase,
@@ -96,6 +97,7 @@
     "qExpectedImprovement",
     "LogImprovementMCAcquisitionFunction",
     "qLogExpectedImprovement",
+    "qLogNoisyExpectedImprovement",
     "qKnowledgeGradient",
     "MaxValueBase",
     "qMultiFidelityKnowledgeGradient",
diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py
index cff41d46e1..c1b3fa02f2 100644
--- a/botorch/acquisition/input_constructors.py
+++ b/botorch/acquisition/input_constructors.py
@@ -47,7 +47,12 @@
     qKnowledgeGradient,
     qMultiFidelityKnowledgeGradient,
 )
-from botorch.acquisition.logei import qLogExpectedImprovement
+from botorch.acquisition.logei import (
+    qLogExpectedImprovement,
+    qLogNoisyExpectedImprovement,
+    TAU_MAX,
+    TAU_RELU,
+)
 from botorch.acquisition.max_value_entropy_search import (
     qMaxValueEntropy,
     qMultiFidelityMaxValueEntropy,
@@ -450,7 +455,7 @@ def construct_inputs_qSimpleRegret(
     )
 
 
-@acqf_input_constructor(qExpectedImprovement, qLogExpectedImprovement)
+@acqf_input_constructor(qExpectedImprovement)
 def construct_inputs_qEI(
     model: Model,
     training_data: MaybeDict[SupervisedDataset],
@@ -508,6 +513,72 @@ def construct_inputs_qEI(
     return {**base_inputs, "best_f": best_f, "constraints": constraints, "eta": eta}
 
 
+@acqf_input_constructor(qLogExpectedImprovement)
+def construct_inputs_qLogEI(
+    model: Model,
+    training_data: MaybeDict[SupervisedDataset],
+    objective: Optional[MCAcquisitionObjective] = None,
+    posterior_transform: Optional[PosteriorTransform] = None,
+    X_pending: Optional[Tensor] = None,
+    sampler: Optional[MCSampler] = None,
+    best_f: Optional[Union[float, Tensor]] = None,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+    eta: Union[Tensor, float] = 1e-3,
+    fat: bool = True,
+    tau_max: float = TAU_MAX,
+    tau_relu: float = TAU_RELU,
+    **ignored: Any,
+) -> Dict[str, Any]:
+    r"""Construct kwargs for the `qLogExpectedImprovement` constructor.
+
+    Args:
+        model: The model to be used in the acquisition function.
+        training_data: Dataset(s) used to train the model.
+        objective: The objective to be used in the acquisition function.
+        posterior_transform: The posterior transform to be used in the
+            acquisition function.
+        X_pending: A `m x d`-dim Tensor of `m` design points that have been
+            submitted for function evaluation but have not yet been evaluated.
+            Concatenated into X upon forward call.
+        sampler: The sampler used to draw base samples. If omitted, uses
+            the acquisition function's default sampler.
+        best_f: Threshold above (or below) which improvement is defined.
+        constraints: A list of constraint callables which map a Tensor of posterior
+            samples of dimension `sample_shape x batch-shape x q x m`-dim to a
+            `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+            are considered satisfied if the output is less than zero.
+        eta: Temperature parameter(s) governing the smoothness of the sigmoid
+            approximation to the constraint indicators. For more details on this
+            parameter, see the docs of `compute_smoothed_feasibility_indicator`.
+        fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
+            approximation to the ReLU.
+        tau_max: Temperature parameter controlling the sharpness of the smooth
+            approximations to max.
+        tau_relu: Temperature parameter controlling the sharpness of the smooth
+            approximations to ReLU.
+        ignored: Not used.
+
+    Returns:
+        A dict mapping kwarg names of the constructor to values.
+    """
+    return {
+        **construct_inputs_qEI(
+            model=model,
+            training_data=training_data,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+            sampler=sampler,
+            best_f=best_f,
+            constraints=constraints,
+            eta=eta,
+        ),
+        "fat": fat,
+        "tau_max": tau_max,
+        "tau_relu": tau_relu,
+    }
+
+
 @acqf_input_constructor(qNoisyExpectedImprovement)
 def construct_inputs_qNEI(
     model: Model,
@@ -570,7 +641,6 @@ def construct_inputs_qNEI(
         assert_shared=True,
         first_only=True,
     )
-
     return {
         **base_inputs,
         "X_baseline": X_baseline,
@@ -581,6 +651,82 @@
     }
 
 
+@acqf_input_constructor(qLogNoisyExpectedImprovement)
+def construct_inputs_qLogNEI(
+    model: Model,
+    training_data: MaybeDict[SupervisedDataset],
+    objective: Optional[MCAcquisitionObjective] = None,
+    posterior_transform: Optional[PosteriorTransform] = None,
+    X_pending: Optional[Tensor] = None,
+    sampler: Optional[MCSampler] = None,
+    X_baseline: Optional[Tensor] = None,
+    prune_baseline: Optional[bool] = True,
+    cache_root: Optional[bool] = True,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+    eta: Union[Tensor, float] = 1e-3,
+    fat: bool = True,
+    tau_max: float = TAU_MAX,
+    tau_relu: float = TAU_RELU,
+    **ignored: Any,
+) -> Dict[str, Any]:
+    r"""Construct kwargs for the `qLogNoisyExpectedImprovement` constructor.
+
+    Args:
+        model: The model to be used in the acquisition function.
+        training_data: Dataset(s) used to train the model.
+        objective: The objective to be used in the acquisition function.
+        posterior_transform: The posterior transform to be used in the
+            acquisition function.
+        X_pending: A `m x d`-dim Tensor of `m` design points that have been
+            submitted for function evaluation but have not yet been evaluated.
+            Concatenated into X upon forward call.
+        sampler: The sampler used to draw base samples. If omitted, uses
+            the acquisition function's default sampler.
+        X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
+            that have already been observed. These points are considered as
+            the potential best design point. If omitted, checks that all
+            training_data have the same input features and takes the first `X`.
+        prune_baseline: If True, remove points in `X_baseline` that are
+            highly unlikely to be the best point. This can significantly
+            improve performance and is generally recommended.
+        cache_root: A boolean indicating whether to cache the root
+            decomposition over `X_baseline` and use low-rank updates.
+        constraints: A list of constraint callables which map a Tensor of posterior
+            samples of dimension `sample_shape x batch-shape x q x m`-dim to a
+            `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+            are considered satisfied if the output is less than zero.
+        eta: Temperature parameter(s) governing the smoothness of the sigmoid
+            approximation to the constraint indicators. For more details on this
+            parameter, see the docs of `compute_smoothed_feasibility_indicator`.
+        fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
+            approximation to the ReLU.
+        tau_max: Temperature parameter controlling the sharpness of the smooth
+            approximations to max.
+        tau_relu: Temperature parameter controlling the sharpness of the smooth
+            approximations to ReLU.
+        ignored: Not used.
+
+    Returns:
+        A dict mapping kwarg names of the constructor to values.
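+
+    Example:
+        A minimal illustrative sketch of calling the constructor directly;
+        `model`, `dataset`, and `train_X` are placeholder names for a fitted
+        model, a `SupervisedDataset`, and previously observed inputs, and are
+        not defined in this module.
+
+        >>> inputs = construct_inputs_qLogNEI(
+        ...     model=model, training_data=dataset, X_baseline=train_X
+        ... )
+        >>> acqf = qLogNoisyExpectedImprovement(**inputs)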
+    """
+    return {
+        **construct_inputs_qNEI(
+            model=model,
+            training_data=training_data,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+            sampler=sampler,
+            X_baseline=X_baseline,
+            prune_baseline=prune_baseline,
+            cache_root=cache_root,
+            constraints=constraints,
+            eta=eta,
+        ),
+        "fat": fat,
+        "tau_max": tau_max,
+        "tau_relu": tau_relu,
+    }
+
+
 @acqf_input_constructor(qProbabilityOfImprovement)
 def construct_inputs_qPI(
     model: Model,
diff --git a/botorch/acquisition/logei.py b/botorch/acquisition/logei.py
index 95affe1c28..35c1562c46 100644
--- a/botorch/acquisition/logei.py
+++ b/botorch/acquisition/logei.py
@@ -7,15 +7,17 @@
 Batch implementations of the LogEI family of improvements-based acquisition
 functions.
 """
-
 from __future__ import annotations
 
 from functools import partial
-from typing import Callable, List, Optional, TypeVar, Union
+from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union
 
 import torch
-from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
+from botorch.acquisition.monte_carlo import (
+    NoisyExpectedImprovementMixin,
+    SampleReducingMCAcquisitionFunction,
+)
 from botorch.acquisition.objective import (
     ConstrainedMCObjective,
     MCAcquisitionObjective,
@@ -219,6 +221,135 @@ def _sample_forward(self, obj: Tensor) -> Tensor:
         return li
 
 
+class qLogNoisyExpectedImprovement(
+    LogImprovementMCAcquisitionFunction, NoisyExpectedImprovementMixin
+):
+    r"""MC-based batch Log Noisy Expected Improvement.
+
+    This function does not assume a `best_f` is known (which would require
+    noiseless observations). Instead, it uses samples from the joint posterior
+    over the `q` test points and previously observed points. A smooth approximation
+    to the canonical improvement over previously observed points is computed
+    for each sample and the logarithm of the average is returned.
+
+    `qLogNEI(X) ~ log(qNEI(X)) = log E(max(max Y - max Y_baseline, 0))`, where
+    `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)`
+
+    Example:
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> sampler = SobolQMCNormalSampler(1024)
+        >>> qLogNEI = qLogNoisyExpectedImprovement(model, train_X, sampler)
+        >>> acqval = qLogNEI(test_X)
+    """
+
+    def __init__(
+        self,
+        model: Model,
+        X_baseline: Tensor,
+        sampler: Optional[MCSampler] = None,
+        objective: Optional[MCAcquisitionObjective] = None,
+        posterior_transform: Optional[PosteriorTransform] = None,
+        X_pending: Optional[Tensor] = None,
+        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+        eta: Union[Tensor, float] = 1e-3,
+        fat: bool = True,
+        prune_baseline: bool = False,
+        cache_root: bool = True,
+        tau_max: float = TAU_MAX,
+        tau_relu: float = TAU_RELU,
+        **kwargs: Any,
+    ) -> None:
+        r"""q-Log Noisy Expected Improvement.
+
+        Args:
+            model: A fitted model.
+            X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
+                that have already been observed. These points are considered as
+                the potential best design point.
+            sampler: The sampler used to draw base samples. See
+                `MCAcquisitionFunction` for more details.
+            objective: The MCAcquisitionObjective under which the samples are
+                evaluated. Defaults to `IdentityMCObjective()`.
+            posterior_transform: A PosteriorTransform (optional).
+            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
+                that have been submitted for function evaluation but have not
+                yet been evaluated. Concatenated into `X` upon forward call.
+                Copied and set to have no gradient.
+            constraints: A list of constraint callables which map a Tensor of posterior
+                samples of dimension `sample_shape x batch-shape x q x m`-dim to a
+                `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+                are satisfied if `constraint(samples) < 0`.
+            eta: Temperature parameter(s) governing the smoothness of the sigmoid
+                approximation to the constraint indicators. See the docs of
+                `compute_(log_)smoothed_constraint_indicator` for details.
+            fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
+                approximation to the ReLU.
+            prune_baseline: If True, remove points in `X_baseline` that are
+                highly unlikely to be the best point. This can significantly
+                improve performance and is generally recommended. In order to
+                customize pruning parameters, instead manually call
+                `botorch.acquisition.utils.prune_inferior_points` on `X_baseline`
+                before instantiating the acquisition function.
+            cache_root: A boolean indicating whether to cache the root
+                decomposition over `X_baseline` and use low-rank updates.
+            tau_max: Temperature parameter controlling the sharpness of the smooth
+                approximations to max.
+            tau_relu: Temperature parameter controlling the sharpness of the smooth
+                approximations to ReLU.
+            kwargs: Kept for compatibility with `qNEI`.
+
+        TODO: similar to qNEHVI, when we are using sequential greedy candidate
+        selection, we could incorporate pending points X_baseline and compute
+        the incremental q(Log)NEI from the new point. This would greatly increase
+        efficiency for large batches.
+        """
+        LogImprovementMCAcquisitionFunction.__init__(
+            self,
+            model=model,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+            constraints=constraints,
+            eta=eta,
+            fat=fat,
+            tau_max=tau_max,
+        )
+        self.tau_relu = tau_relu
+        NoisyExpectedImprovementMixin.__init__(
+            self,
+            model=model,
+            X_baseline=X_baseline,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            prune_baseline=prune_baseline,
+            cache_root=cache_root,
+            **kwargs,
+        )
+
+    def _sample_forward(self, obj: Tensor) -> Tensor:
+        r"""Evaluate qLogNoisyExpectedImprovement per sample on the candidate set `X`.
+
+        Args:
+            obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values.
+
+        Returns:
+            A `sample_shape x batch_shape x q`-dim Tensor of smoothed log
+            improvement values.
+        """
+        return _log_improvement(
+            Y=obj,
+            best_f=self.compute_best_f(obj),
+            tau=self.tau_relu,
+            fat=self._fat,
+        )
+
+    def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]:
+        # Explicit, as both parent classes have this method, so no MRO magic required.
+        return NoisyExpectedImprovementMixin._get_samples_and_objectives(self, X)
+
+
 """
 ###################################### utils ##########################################
 """
diff --git a/botorch/acquisition/monte_carlo.py b/botorch/acquisition/monte_carlo.py
index 5d58f8d332..bdcb4ae052 100644
--- a/botorch/acquisition/monte_carlo.py
+++ b/botorch/acquisition/monte_carlo.py
@@ -402,58 +402,30 @@ def _sample_forward(self, obj: Tensor) -> Tensor:
         return (obj - self.best_f.unsqueeze(-1).to(obj)).clamp_min(0)
 
 
-class qNoisyExpectedImprovement(
-    SampleReducingMCAcquisitionFunction, CachedCholeskyMCAcquisitionFunction
-):
-    r"""MC-based batch Noisy Expected Improvement.
-
-    This function does not assume a `best_f` is known (which would require
-    noiseless observations). Instead, it uses samples from the joint posterior
-    over the `q` test points and previously observed points. The improvement
-    over previously observed points is computed for each sample and averaged.
-
-    `qNEI(X) = E(max(max Y - max Y_baseline, 0))`, where
-    `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)`
-
-    Example:
-        >>> model = SingleTaskGP(train_X, train_Y)
-        >>> sampler = SobolQMCNormalSampler(1024)
-        >>> qNEI = qNoisyExpectedImprovement(model, train_X, sampler)
-        >>> qnei = qNEI(test_X)
+class NoisyExpectedImprovementMixin(CachedCholeskyMCAcquisitionFunction):
+    """A mixin class to share code between qNEI and qLogNEI. In particular,
+    it unifies the
+    1) initialization of the baseline samples and objectives,
+    2) initialization of the cached Cholesky decomposition of the kernel matrix,
+    3) computation (resp. updating) of new (resp. baseline) samples and objectives, and
+    4) computation of the best feasible objective.
     """
 
     def __init__(
         self,
-        model: Model,
         X_baseline: Tensor,
-        sampler: Optional[MCSampler] = None,
-        objective: Optional[MCAcquisitionObjective] = None,
-        posterior_transform: Optional[PosteriorTransform] = None,
-        X_pending: Optional[Tensor] = None,
-        prune_baseline: bool = True,
+        prune_baseline: bool = False,
         cache_root: bool = True,
-        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
-        eta: Union[Tensor, float] = 1e-3,
         **kwargs: Any,
     ) -> None:
-        r"""q-Noisy Expected Improvement.
+        r"""Noisy Expected Improvement Mixin.
+
+        NOTE: `SampleReducingMCAcquisitionFunction.__init__` must have been
+        executed on the object before this `__init__` method is called.
 
         Args:
-            model: A fitted model.
             X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
                 that have already been observed. These points are considered as
                 the potential best design point.
-            sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
-                more details.
-            objective: The MCAcquisitionObjective under which the samples are
-                evaluated. Defaults to `IdentityMCObjective()`.
-                NOTE: `ConstrainedMCObjective` for outcome constraints is deprecated in
-                favor of passing the `constraints` directly to this constructor.
-            posterior_transform: A PosteriorTransform (optional).
-            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
-                that have points that have been submitted for function evaluation
-                but have not yet been evaluated. Concatenated into `X` upon
-                forward call. Copied and set to have no gradient.
             prune_baseline: If True, remove points in `X_baseline` that are
                 highly unlikely to be the best point. This can significantly
                 improve performance and is generally recommended. In order to
@@ -462,35 +434,16 @@ def __init__(
                 before instantiating the acquisition function.
             cache_root: A boolean indicating whether to cache the root
                 decomposition over `X_baseline` and use low-rank updates.
-            constraints: A list of constraint callables which map a Tensor of posterior
-                samples of dimension `sample_shape x batch-shape x q x m`-dim to a
-                `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
-                are considered satisfied if the output is less than zero.
-            eta: Temperature parameter(s) governing the smoothness of the sigmoid
-                approximation to the constraint indicators. For more details, on this
-                parameter, see the docs of `compute_smoothed_feasibility_indicator`.
- - TODO: similar to qNEHVI, when we are using sequential greedy candidate - selection, we could incorporate pending points X_baseline and compute - the incremental qNEI from the new point. This would greatly increase - efficiency for large batches. """ - super().__init__( - model=model, - sampler=sampler, - objective=objective, - posterior_transform=posterior_transform, - X_pending=X_pending, - constraints=constraints, - eta=eta, - ) - self._setup(model=model, cache_root=cache_root) + # setup of CachedCholeskyMCAcquisitionFunction + # self.model initialized by `SampleReducingMCAcquisitionFunctions.__init__` + self._setup(model=self.model, cache_root=cache_root) if prune_baseline: X_baseline = prune_inferior_points( - model=model, + model=self.model, X=X_baseline, - objective=objective, - posterior_transform=posterior_transform, + objective=self.objective, + posterior_transform=self.posterior_transform, marginalize_dim=kwargs.get("marginalize_dim"), ) self.register_buffer("X_baseline", X_baseline) @@ -508,65 +461,26 @@ def __init__( # may be confusing to have two different caches, but this is not # trivial to change since each is needed for a different reason: # - LinearOperator caching to `posterior.mvn` allows for reuse within - # this function, which may be helpful if the same root decomposition - # is produced by the calls to `self.base_sampler` and - # `self._cache_root_decomposition`. + # this function, which may be helpful if the same root decomposition + # is produced by the calls to `self.base_sampler` and + # `self._cache_root_decomposition`. # - self._baseline_L allows a root decomposition to be persisted outside # this method. - baseline_samples = self.get_posterior_samples(posterior) - baseline_obj = self.objective(baseline_samples, X=X_baseline) + self.baseline_samples = self.get_posterior_samples(posterior) + self.baseline_obj = self.objective(self.baseline_samples, X=X_baseline) # We make a copy here because we will write an attribute `base_samples` # to `self.base_sampler.base_samples`, and we don't want to mutate # `self.sampler`. self.base_sampler = deepcopy(self.sampler) - self.baseline_samples = baseline_samples - self.baseline_obj = baseline_obj self.register_buffer( "_baseline_best_f", self._compute_best_feasible_objective( - samples=baseline_samples, obj=baseline_obj + samples=self.baseline_samples, obj=self.baseline_obj ), ) self._baseline_L = self._compute_root_decomposition(posterior=posterior) - def compute_best_f(self, obj: Tensor) -> Tensor: - """Computes the best (feasible) noisy objective value. - - Args: - obj: `sample_shape x batch_shape x q`-dim Tensor of objectives in forward. - - Returns: - A `sample_shape x batch_shape x 1`-dim Tensor of best feasible objectives. - """ - if self._cache_root: - val = self._baseline_best_f - else: - val = self._compute_best_feasible_objective( - samples=self.baseline_samples, obj=self.baseline_obj - ) - # ensuring shape, dtype, device compatibility with obj - n_sample_dims = len(self.sample_shape) - view_shape = torch.Size( - [ - *val.shape[:n_sample_dims], # sample dimensions - *(1,) * (obj.ndim - val.ndim), # pad to match obj - *val.shape[n_sample_dims:], # the rest - ] - ) - return val.view(view_shape).to(obj) - - def _sample_forward(self, obj: Tensor) -> Tensor: - r"""Evaluate qNoisyExpectedImprovement per sample on the candidate set `X`. - - Args: - obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values. - - Returns: - A `sample_shape x batch_shape x q`-dim Tensor of noisy improvement values. 
- """ - return (obj - self.compute_best_f(obj)).clamp_min(0) - def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]: r"""Compute samples at new points, using the cached root decomposition. @@ -578,7 +492,7 @@ def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]: samples with shape `sample_shape x batch_shape x q x m`, and `obj` is a tensor of MC objective values with shape `sample_shape x batch_shape x q`. """ - q = X.shape[-2] + n_baseline, q = self.X_baseline.shape[-2], X.shape[-2] X_full = torch.cat([match_batch_shape(self.X_baseline, X), X], dim=-2) # TODO: Implement more efficient way to compute posterior over both training and # test points in GPyTorch (https://github.com/cornellius-gp/gpytorch/issues/567) @@ -586,12 +500,11 @@ def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]: X_full, posterior_transform=self.posterior_transform ) if not self._cache_root: - samples_full = super().get_posterior_samples(posterior) - samples = samples_full[..., -q:, :] + samples_full = MCSamplerMixin.get_posterior_samples(self, posterior) obj_full = self.objective(samples_full, X=X_full) # assigning baseline buffers so `best_f` can be computed in _sample_forward - self.baseline_obj, obj = obj_full[..., :-q], obj_full[..., -q:] - self.baseline_samples = samples_full[..., :-q, :] + self.baseline_samples, samples = samples_full.split([n_baseline, q], dim=-2) + self.baseline_obj, obj = obj_full.split([n_baseline, q], dim=-1) return samples, obj # handle one-to-many input transforms @@ -603,16 +516,33 @@ def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]: obj = self.objective(samples, X=X_full[..., -q:, :]) return samples, obj - def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tensor: - r"""Computes best feasible objective value from samples. + def compute_best_f(self, obj: Tensor) -> Tensor: + """Computes the best (feasible) noisy objective value. Args: - samples: `sample_shape x batch_shape x q x m`-dim posterior samples. - obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values. + obj: `sample_shape x batch_shape x q`-dim Tensor of objectives in forward. Returns: A `sample_shape x batch_shape x 1`-dim Tensor of best feasible objectives. """ + if self._cache_root: + val = self._baseline_best_f + else: + val = self._compute_best_feasible_objective( + samples=self.baseline_samples, obj=self.baseline_obj + ) + # ensuring shape, dtype, device compatibility with obj + n_sample_dims = len(self.sample_shape) + view_shape = torch.Size( + [ + *val.shape[:n_sample_dims], # sample dimensions + *(1,) * (obj.ndim - val.ndim), # pad to match obj + *val.shape[n_sample_dims:], # the rest + ] + ) + return val.view(view_shape).to(obj) + + def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tensor: return compute_best_feasible_objective( samples=samples, obj=obj, @@ -624,6 +554,117 @@ def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tens ) +class qNoisyExpectedImprovement( + SampleReducingMCAcquisitionFunction, NoisyExpectedImprovementMixin +): + r"""MC-based batch Noisy Expected Improvement. + + This function does not assume a `best_f` is known (which would require + noiseless observations). Instead, it uses samples from the joint posterior + over the `q` test points and previously observed points. The improvement + over previously observed points is computed for each sample and averaged. 
+ + `qNEI(X) = E(max(max Y - max Y_baseline, 0))`, where + `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)` + + Example: + >>> model = SingleTaskGP(train_X, train_Y) + >>> sampler = SobolQMCNormalSampler(1024) + >>> qNEI = qNoisyExpectedImprovement(model, train_X, sampler) + >>> qnei = qNEI(test_X) + """ + + def __init__( + self, + model: Model, + X_baseline: Tensor, + sampler: Optional[MCSampler] = None, + objective: Optional[MCAcquisitionObjective] = None, + posterior_transform: Optional[PosteriorTransform] = None, + X_pending: Optional[Tensor] = None, + prune_baseline: bool = True, + cache_root: bool = True, + constraints: Optional[List[Callable[[Tensor], Tensor]]] = None, + eta: Union[Tensor, float] = 1e-3, + **kwargs: Any, + ) -> None: + r"""q-Noisy Expected Improvement. + + Args: + model: A fitted model. + X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points + that have already been observed. These points are considered as + the potential best design point. + sampler: The sampler used to draw base samples. See `MCAcquisitionFunction` + more details. + objective: The MCAcquisitionObjective under which the samples are + evaluated. Defaults to `IdentityMCObjective()`. + NOTE: `ConstrainedMCObjective` for outcome constraints is deprecated in + favor of passing the `constraints` directly to this constructor. + posterior_transform: A PosteriorTransform (optional). + X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points + that have points that have been submitted for function evaluation + but have not yet been evaluated. Concatenated into `X` upon + forward call. Copied and set to have no gradient. + prune_baseline: If True, remove points in `X_baseline` that are + highly unlikely to be the best point. This can significantly + improve performance and is generally recommended. In order to + customize pruning parameters, instead manually call + `botorch.acquisition.utils.prune_inferior_points` on `X_baseline` + before instantiating the acquisition function. + cache_root: A boolean indicating whether to cache the root + decomposition over `X_baseline` and use low-rank updates. + constraints: A list of constraint callables which map a Tensor of posterior + samples of dimension `sample_shape x batch-shape x q x m`-dim to a + `sample_shape x batch-shape x q`-dim Tensor. The associated constraints + are considered satisfied if the output is less than zero. + eta: Temperature parameter(s) governing the smoothness of the sigmoid + approximation to the constraint indicators. For more details, on this + parameter, see the docs of `compute_smoothed_feasibility_indicator`. + + TODO: similar to qNEHVI, when we are using sequential greedy candidate + selection, we could incorporate pending points X_baseline and compute + the incremental qNEI from the new point. This would greatly increase + efficiency for large batches. + """ + SampleReducingMCAcquisitionFunction.__init__( + self, + model=model, + sampler=sampler, + objective=objective, + posterior_transform=posterior_transform, + X_pending=X_pending, + constraints=constraints, + eta=eta, + ) + NoisyExpectedImprovementMixin.__init__( + self, + model=model, + X_baseline=X_baseline, + sampler=sampler, + objective=objective, + posterior_transform=posterior_transform, + prune_baseline=prune_baseline, + cache_root=cache_root, + **kwargs, + ) + + def _sample_forward(self, obj: Tensor) -> Tensor: + r"""Evaluate qNoisyExpectedImprovement per sample on the candidate set `X`. 
+ + Args: + obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values. + + Returns: + A `sample_shape x batch_shape x q`-dim Tensor of noisy improvement values. + """ + return (obj - self.compute_best_f(obj)).clamp_min(0) + + def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]: + # Explicit, as both parent classes have this method, so no MRO magic required. + return NoisyExpectedImprovementMixin._get_samples_and_objectives(self, X) + + class qProbabilityOfImprovement(SampleReducingMCAcquisitionFunction): r"""MC-based batch Probability of Improvement. diff --git a/test/acquisition/test_input_constructors.py b/test/acquisition/test_input_constructors.py index e2a59d34ad..ea8f95b81e 100644 --- a/test/acquisition/test_input_constructors.py +++ b/test/acquisition/test_input_constructors.py @@ -31,6 +31,12 @@ qKnowledgeGradient, qMultiFidelityKnowledgeGradient, ) +from botorch.acquisition.logei import ( + qLogExpectedImprovement, + qLogNoisyExpectedImprovement, + TAU_MAX, + TAU_RELU, +) from botorch.acquisition.max_value_entropy_search import ( qMaxValueEntropy, qMultiFidelityMaxValueEntropy, @@ -382,6 +388,23 @@ def test_construct_inputs_qEI(self): ) self.assertEqual(kwargs["best_f"], best_f_expected) + # testing qLogEI input constructor + log_constructor = get_acqf_input_constructor(qLogExpectedImprovement) + log_kwargs = log_constructor( + model=mock_model, + training_data=self.blockX_blockY, + objective=objective, + X_pending=X_pending, + best_f=best_f_expected, + ) + # includes strict superset of kwargs tested above + self.assertTrue(kwargs.items() <= log_kwargs.items()) + self.assertTrue("fat" in log_kwargs) + self.assertTrue("tau_max" in log_kwargs) + self.assertEqual(log_kwargs["tau_max"], TAU_MAX) + self.assertTrue("tau_relu" in log_kwargs) + self.assertEqual(log_kwargs["tau_relu"], TAU_RELU) + def test_construct_inputs_qNEI(self): c = get_acqf_input_constructor(qNoisyExpectedImprovement) mock_model = mock.Mock() @@ -415,6 +438,22 @@ def test_construct_inputs_qNEI(self): self.assertIsInstance(kwargs["eta"], float) self.assertTrue(kwargs["eta"] < 1) + # testing qLogNEI input constructor + log_constructor = get_acqf_input_constructor(qLogNoisyExpectedImprovement) + log_kwargs = log_constructor( + model=mock_model, + training_data=self.blockX_blockY, + X_baseline=X_baseline, + prune_baseline=False, + ) + # includes strict superset of kwargs tested above + self.assertTrue(kwargs.items() <= log_kwargs.items()) + self.assertTrue("fat" in log_kwargs) + self.assertTrue("tau_max" in log_kwargs) + self.assertEqual(log_kwargs["tau_max"], TAU_MAX) + self.assertTrue("tau_relu" in log_kwargs) + self.assertEqual(log_kwargs["tau_relu"], TAU_RELU) + def test_construct_inputs_qPI(self): c = get_acqf_input_constructor(qProbabilityOfImprovement) mock_model = mock.Mock() diff --git a/test/acquisition/test_logei.py b/test/acquisition/test_logei.py index 6d84ad671e..f6611c0368 100644 --- a/test/acquisition/test_logei.py +++ b/test/acquisition/test_logei.py @@ -5,6 +5,9 @@ # LICENSE file in the root directory of this source tree. 
import warnings +from copy import deepcopy +from itertools import product +from math import pi from unittest import mock import torch @@ -12,18 +15,29 @@ from botorch.acquisition import ( LogImprovementMCAcquisitionFunction, qLogExpectedImprovement, + qLogNoisyExpectedImprovement, ) from botorch.acquisition.input_constructors import ACQF_INPUT_CONSTRUCTOR_REGISTRY -from botorch.acquisition.monte_carlo import qExpectedImprovement +from botorch.acquisition.monte_carlo import ( + qExpectedImprovement, + qNoisyExpectedImprovement, +) + from botorch.acquisition.objective import ( ConstrainedMCObjective, + GenericMCObjective, IdentityMCObjective, PosteriorTransform, + ScalarizedPosteriorTransform, ) from botorch.exceptions import BotorchWarning, UnsupportedError from botorch.exceptions.errors import BotorchError +from botorch.models import SingleTaskGP from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler +from botorch.utils.low_rank import sample_cached_cholesky from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior + +from botorch.utils.transforms import standardize from torch import Tensor @@ -127,6 +141,8 @@ def test_q_log_expected_improvement(self): self.assertEqual(res.item(), expected_val) # Further away from zero, the value is numerically indistinguishable with # single precision arithmetic. + self.assertEqual(exp_log_res.dtype, dtype) + self.assertEqual(exp_log_res.device.type, self.device.type) self.assertTrue(expected_val <= exp_log_res.item()) self.assertTrue(exp_log_res.item() <= expected_val + log_acqf.tau_relu) @@ -294,3 +310,385 @@ def test_q_log_expected_improvement_batch(self): self.assertTrue(torch.equal(acqf.sampler.base_samples, bs)) # # TODO: Test different objectives (incl. constraints) + + +class TestQLogNoisyExpectedImprovement(BotorchTestCase): + def test_q_log_noisy_expected_improvement(self): + self.assertIn( + qLogNoisyExpectedImprovement, ACQF_INPUT_CONSTRUCTOR_REGISTRY.keys() + ) + for dtype in (torch.float, torch.double): + # the event shape is `b x q x t` = 1 x 2 x 1 + samples_noisy = torch.tensor([0.0, 1.0], device=self.device, dtype=dtype) + samples_noisy = samples_noisy.view(1, 2, 1) + # X_baseline is `q' x d` = 1 x 1 + X_baseline = torch.zeros(1, 1, device=self.device, dtype=dtype) + mm_noisy = MockModel(MockPosterior(samples=samples_noisy)) + # X is `q x d` = 1 x 1 + X = torch.zeros(1, 1, device=self.device, dtype=dtype) + + # basic test + sampler = IIDNormalSampler(sample_shape=torch.Size([2])) + kwargs = { + "model": mm_noisy, + "X_baseline": X_baseline, + "sampler": sampler, + "prune_baseline": False, + "cache_root": False, + } + acqf = qNoisyExpectedImprovement(**kwargs) + log_acqf = qLogNoisyExpectedImprovement(**kwargs) + + res = acqf(X) + self.assertEqual(res.item(), 1.0) + log_res = log_acqf(X) + self.assertEqual(log_res.dtype, dtype) + self.assertEqual(log_res.device.type, self.device.type) + self.assertAllClose(log_res.exp().item(), 1.0) + + # basic test + sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345) + kwargs = { + "model": mm_noisy, + "X_baseline": X_baseline, + "sampler": sampler, + "prune_baseline": False, + "cache_root": False, + } + log_acqf = qLogNoisyExpectedImprovement(**kwargs) + log_res = log_acqf(X) + self.assertEqual(log_res.exp().item(), 1.0) + self.assertEqual( + log_acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]) + ) + bs = log_acqf.sampler.base_samples.clone() + log_acqf(X) + self.assertTrue(torch.equal(log_acqf.sampler.base_samples, bs)) + + # basic test, 
qmc + sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2])) + kwargs = { + "model": mm_noisy, + "X_baseline": X_baseline, + "sampler": sampler, + "prune_baseline": False, + "cache_root": False, + } + log_acqf = qLogNoisyExpectedImprovement(**kwargs) + log_res = log_acqf(X) + self.assertEqual(log_res.exp().item(), 1.0) + self.assertEqual( + log_acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]) + ) + bs = log_acqf.sampler.base_samples.clone() + log_acqf(X) + self.assertTrue(torch.equal(log_acqf.sampler.base_samples, bs)) + + # basic test for X_pending and warning + sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2])) + samples_noisy_pending = torch.tensor( + [1.0, 0.0, 0.0], device=self.device, dtype=dtype + ) + samples_noisy_pending = samples_noisy_pending.view(1, 3, 1) + mm_noisy_pending = MockModel(MockPosterior(samples=samples_noisy_pending)) + kwargs = { + "model": mm_noisy_pending, + "X_baseline": X_baseline, + "sampler": sampler, + "prune_baseline": False, + "cache_root": False, + } + # copy for log version + log_acqf = qLogNoisyExpectedImprovement(**kwargs) + log_acqf.set_X_pending() + self.assertIsNone(log_acqf.X_pending) + log_acqf.set_X_pending(None) + self.assertIsNone(log_acqf.X_pending) + log_acqf.set_X_pending(X) + self.assertEqual(log_acqf.X_pending, X) + log_acqf(X) + X2 = torch.zeros( + 1, 1, 1, device=self.device, dtype=dtype, requires_grad=True + ) + with warnings.catch_warnings(record=True) as ws, settings.debug(True): + log_acqf.set_X_pending(X2) + self.assertEqual(log_acqf.X_pending, X2) + self.assertEqual( + sum(issubclass(w.category, BotorchWarning) for w in ws), 1 + ) + + def test_q_noisy_expected_improvement_batch(self): + for dtype in (torch.float, torch.double): + # the event shape is `b x q x t` = 2 x 3 x 1 + samples_noisy = torch.zeros(2, 3, 1, device=self.device, dtype=dtype) + samples_noisy[0, -1, 0] = 1.0 + mm_noisy = MockModel(MockPosterior(samples=samples_noisy)) + # X is `q x d` = 1 x 1 + X = torch.zeros(2, 2, 1, device=self.device, dtype=dtype) + X_baseline = torch.zeros(1, 1, device=self.device, dtype=dtype) + + # test batch mode + sampler = IIDNormalSampler(sample_shape=torch.Size([2])) + kwargs = { + "model": mm_noisy, + "X_baseline": X_baseline, + "sampler": sampler, + "prune_baseline": False, + "cache_root": False, + } + acqf = qLogNoisyExpectedImprovement(**kwargs) + res = acqf(X).exp() + expected_res = torch.tensor([1.0, 0.0], dtype=dtype, device=self.device) + self.assertAllClose(res, expected_res, atol=acqf.tau_relu) + self.assertGreater(res[1].item(), 0.0) + self.assertGreater(acqf.tau_relu, res[1].item()) + + # test batch mode + sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345) + acqf = qLogNoisyExpectedImprovement( + model=mm_noisy, + X_baseline=X_baseline, + sampler=sampler, + prune_baseline=False, + cache_root=False, + ) + res = acqf(X).exp() # 1-dim batch + expected_res = torch.tensor([1.0, 0.0], dtype=dtype, device=self.device) + self.assertAllClose(res, expected_res, atol=acqf.tau_relu) + self.assertGreater(res[1].item(), 0.0) + self.assertGreater(acqf.tau_relu, res[1].item()) + self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 3, 1])) + bs = acqf.sampler.base_samples.clone() + acqf(X) + self.assertTrue(torch.equal(acqf.sampler.base_samples, bs)) + res = acqf(X.expand(2, 2, 1)).exp() # 2-dim batch + expected_res = torch.tensor([1.0, 0.0], dtype=dtype, device=self.device) + self.assertAllClose(res, expected_res, atol=acqf.tau_relu) + self.assertGreater(res[1].item(), 0.0) + 
self.assertGreater(acqf.tau_relu, res[1].item()) + # the base samples should have the batch dim collapsed + self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 3, 1])) + bs = acqf.sampler.base_samples.clone() + acqf(X.expand(2, 2, 1)) + self.assertTrue(torch.equal(acqf.sampler.base_samples, bs)) + + # test batch mode, qmc + sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2])) + acqf = qLogNoisyExpectedImprovement( + model=mm_noisy, + X_baseline=X_baseline, + sampler=sampler, + prune_baseline=False, + cache_root=False, + ) + res = acqf(X).exp() + expected_res = torch.tensor([1.0, 0.0], dtype=dtype, device=self.device) + self.assertAllClose(res, expected_res, atol=acqf.tau_relu) + self.assertGreater(res[1].item(), 0.0) + self.assertGreater(acqf.tau_relu, res[1].item()) + self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 3, 1])) + bs = acqf.sampler.base_samples.clone() + acqf(X) + self.assertTrue(torch.equal(acqf.sampler.base_samples, bs)) + + def test_prune_baseline(self): + no = "botorch.utils.testing.MockModel.num_outputs" + prune = "botorch.acquisition.monte_carlo.prune_inferior_points" + for dtype in (torch.float, torch.double): + X_baseline = torch.zeros(1, 1, device=self.device, dtype=dtype) + X_pruned = torch.rand(1, 1, device=self.device, dtype=dtype) + with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs: + mock_num_outputs.return_value = 1 + mm = MockModel(MockPosterior(samples=X_baseline)) + with mock.patch(prune, return_value=X_pruned) as mock_prune: + acqf = qLogNoisyExpectedImprovement( + model=mm, + X_baseline=X_baseline, + prune_baseline=True, + cache_root=False, + ) + mock_prune.assert_called_once() + self.assertTrue(torch.equal(acqf.X_baseline, X_pruned)) + with mock.patch(prune, return_value=X_pruned) as mock_prune: + acqf = qLogNoisyExpectedImprovement( + model=mm, + X_baseline=X_baseline, + prune_baseline=True, + marginalize_dim=-3, + cache_root=False, + ) + _, kwargs = mock_prune.call_args + self.assertEqual(kwargs["marginalize_dim"], -3) + + def test_cache_root(self): + sample_cached_path = ( + "botorch.acquisition.cached_cholesky.sample_cached_cholesky" + ) + raw_state_dict = { + "likelihood.noise_covar.raw_noise": torch.tensor( + [[0.0895], [0.2594]], dtype=torch.float64 + ), + "mean_module.raw_constant": torch.tensor( + [-0.4545, -0.1285], dtype=torch.float64 + ), + "covar_module.raw_outputscale": torch.tensor( + [1.4876, 1.4897], dtype=torch.float64 + ), + "covar_module.base_kernel.raw_lengthscale": torch.tensor( + [[[-0.7202, -0.2868]], [[-0.8794, -1.2877]]], dtype=torch.float64 + ), + } + # test batched models (e.g. 
for MCMC) + for train_batch_shape, m, dtype in product( + (torch.Size([]), torch.Size([3])), (1, 2), (torch.float, torch.double) + ): + state_dict = deepcopy(raw_state_dict) + for k, v in state_dict.items(): + if m == 1: + v = v[0] + if len(train_batch_shape) > 0: + v = v.unsqueeze(0).expand(*train_batch_shape, *v.shape) + state_dict[k] = v + tkwargs = {"device": self.device, "dtype": dtype} + if m == 2: + objective = GenericMCObjective(lambda Y, X: Y.sum(dim=-1)) + else: + objective = None + for k, v in state_dict.items(): + state_dict[k] = v.to(**tkwargs) + all_close_kwargs = ( + { + "atol": 1e-1, + "rtol": 0.0, + } + if dtype == torch.float + else {"atol": 1e-4, "rtol": 0.0} + ) + torch.manual_seed(1234) + train_X = torch.rand(*train_batch_shape, 3, 2, **tkwargs) + train_Y = ( + torch.sin(train_X * 2 * pi) + + torch.randn(*train_batch_shape, 3, 2, **tkwargs) + )[..., :m] + train_Y = standardize(train_Y) + model = SingleTaskGP( + train_X, + train_Y, + ) + if len(train_batch_shape) > 0: + X_baseline = train_X[0] + else: + X_baseline = train_X + model.load_state_dict(state_dict, strict=False) + sampler = IIDNormalSampler(sample_shape=torch.Size([5]), seed=0) + torch.manual_seed(0) + acqf = qLogNoisyExpectedImprovement( + model=model, + X_baseline=X_baseline, + sampler=sampler, + objective=objective, + prune_baseline=False, + cache_root=True, + ) + + orig_base_samples = acqf.base_sampler.base_samples.detach().clone() + sampler2 = IIDNormalSampler(sample_shape=torch.Size([5]), seed=0) + sampler2.base_samples = orig_base_samples + torch.manual_seed(0) + acqf_no_cache = qLogNoisyExpectedImprovement( + model=model, + X_baseline=X_baseline, + sampler=sampler2, + objective=objective, + prune_baseline=False, + cache_root=False, + ) + for q, batch_shape in product( + (1, 3), (torch.Size([]), torch.Size([3]), torch.Size([4, 3])) + ): + acqf.q_in = -1 + acqf_no_cache.q_in = -1 + test_X = ( + 0.3 + 0.05 * torch.randn(*batch_shape, q, 2, **tkwargs) + ).requires_grad_(True) + with mock.patch( + sample_cached_path, wraps=sample_cached_cholesky + ) as mock_sample_cached: + torch.manual_seed(0) + val = acqf(test_X).exp() + mock_sample_cached.assert_called_once() + val.sum().backward() + base_samples = acqf.sampler.base_samples.detach().clone() + X_grad = test_X.grad.clone() + test_X2 = test_X.detach().clone().requires_grad_(True) + acqf_no_cache.sampler.base_samples = base_samples + with mock.patch( + sample_cached_path, wraps=sample_cached_cholesky + ) as mock_sample_cached: + torch.manual_seed(0) + val2 = acqf_no_cache(test_X2).exp() + mock_sample_cached.assert_not_called() + self.assertAllClose(val, val2, **all_close_kwargs) + val2.sum().backward() + self.assertAllClose(X_grad, test_X2.grad, **all_close_kwargs) + # test we fall back to standard sampling for + # ill-conditioned covariances + acqf._baseline_L = torch.zeros_like(acqf._baseline_L) + with warnings.catch_warnings(record=True) as ws, settings.debug(True): + with torch.no_grad(): + acqf(test_X) + self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1) + + # test w/ posterior transform + X_baseline = torch.rand(2, 1) + model = SingleTaskGP(X_baseline, torch.randn(2, 1)) + pt = ScalarizedPosteriorTransform(weights=torch.tensor([-1])) + with mock.patch.object( + qLogNoisyExpectedImprovement, + "_compute_root_decomposition", + ) as mock_cache_root: + acqf = qLogNoisyExpectedImprovement( + model=model, + X_baseline=X_baseline, + sampler=IIDNormalSampler(sample_shape=torch.Size([1])), + posterior_transform=pt, + 
prune_baseline=False, + cache_root=True, + ) + tf_post = model.posterior(X_baseline, posterior_transform=pt) + self.assertTrue( + torch.allclose( + tf_post.mean, mock_cache_root.call_args[-1]["posterior"].mean + ) + ) + + # testing constraints + n, d, m = 8, 1, 3 + X_baseline = torch.rand(n, d) + model = SingleTaskGP(X_baseline, torch.randn(n, m)) # batched model + nei_args = { + "model": model, + "X_baseline": X_baseline, + "prune_baseline": False, + "cache_root": True, + "posterior_transform": ScalarizedPosteriorTransform(weights=torch.ones(m)), + "sampler": SobolQMCNormalSampler(torch.Size([5])), + } + acqf = qLogNoisyExpectedImprovement(**nei_args) + X = torch.randn_like(X_baseline) + for con in [feasible_con, infeasible_con]: + with self.subTest(con=con): + target = "botorch.acquisition.utils.get_infeasible_cost" + infcost = torch.tensor([3], device=self.device, dtype=dtype) + with mock.patch(target, return_value=infcost): + cacqf = qLogNoisyExpectedImprovement(**nei_args, constraints=[con]) + + _, obj = cacqf._get_samples_and_objectives(X) + best_feas_f = cacqf.compute_best_f(obj) + if con is feasible_con: + self.assertAllClose(best_feas_f, acqf.compute_best_f(obj)) + else: + self.assertAllClose( + best_feas_f, torch.full_like(obj[..., [0]], -infcost.item()) + ) + # TODO: Test different objectives (incl. constraints) diff --git a/test/acquisition/test_monte_carlo.py b/test/acquisition/test_monte_carlo.py index c86fcc1e3c..9a70133eab 100644 --- a/test/acquisition/test_monte_carlo.py +++ b/test/acquisition/test_monte_carlo.py @@ -246,6 +246,7 @@ def test_q_noisy_expected_improvement(self): ) res = acqf(X) self.assertEqual(res.item(), 1.0) + self.assertEqual(sampler, acqf.sampler) # basic test sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345) @@ -556,7 +557,7 @@ def test_cache_root(self): "prune_baseline": False, "cache_root": True, "posterior_transform": ScalarizedPosteriorTransform(weights=torch.ones(m)), - "sampler": SobolQMCNormalSampler(5), + "sampler": SobolQMCNormalSampler(sample_shape=torch.Size([5])), } acqf = qNoisyExpectedImprovement(**nei_args) X = torch.randn_like(X_baseline) @@ -938,7 +939,6 @@ def test_mc_acquisition_function_with_constraints(self): partial(qExpectedImprovement, model=mm, best_f=0.0), # cache_root=True not supported by MockModel, see test_cache_root partial(qNoisyExpectedImprovement, cache_root=False, **nei_args), - partial(qNoisyExpectedImprovement, cache_root=True, **nei_args), ]: acqf = acqf_constructor() mm._posterior._samples = ( diff --git a/test/models/test_fully_bayesian.py b/test/models/test_fully_bayesian.py index 5994558026..52897adb54 100644 --- a/test/models/test_fully_bayesian.py +++ b/test/models/test_fully_bayesian.py @@ -18,6 +18,10 @@ ProbabilityOfImprovement, UpperConfidenceBound, ) +from botorch.acquisition.logei import ( + qLogExpectedImprovement, + qLogNoisyExpectedImprovement, +) from botorch.acquisition.monte_carlo import ( qExpectedImprovement, qNoisyExpectedImprovement, @@ -411,6 +415,7 @@ def test_acquisition_functions(self): model, warmup_steps=8, num_samples=5, thinning=2, disable_progbar=True ) deterministic = GenericDeterministicModel(f=lambda x: x[..., :1]) + # due to ModelList type, setting cache_root=False for all noisy EI variants list_gp = ModelListGP(model, model) mixed_list = ModelList(deterministic, model) simple_sampler = get_sampler( @@ -427,11 +432,23 @@ def test_acquisition_functions(self): ProbabilityOfImprovement(model=model, best_f=train_Y.max()), PosteriorMean(model=model), 
UpperConfidenceBound(model=model, beta=4), + qLogExpectedImprovement( + model=model, best_f=train_Y.max(), sampler=simple_sampler + ), qExpectedImprovement( model=model, best_f=train_Y.max(), sampler=simple_sampler ), + qLogNoisyExpectedImprovement( + model=model, + X_baseline=train_X, + sampler=simple_sampler, + cache_root=False, + ), qNoisyExpectedImprovement( - model=model, X_baseline=train_X, sampler=simple_sampler + model=model, + X_baseline=train_X, + sampler=simple_sampler, + cache_root=False, ), qProbabilityOfImprovement( model=model, best_f=train_Y.max(), sampler=simple_sampler @@ -443,6 +460,7 @@ def test_acquisition_functions(self): X_baseline=train_X, ref_point=torch.zeros(2, **tkwargs), sampler=list_gp_sampler, + cache_root=False, ), qExpectedHypervolumeImprovement( model=list_gp, @@ -458,6 +476,7 @@ def test_acquisition_functions(self): X_baseline=train_X, ref_point=torch.zeros(2, **tkwargs), sampler=mixed_list_sampler, + cache_root=False, ), qExpectedHypervolumeImprovement( model=mixed_list,