Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
8b53bd6
acquisition function wrapper (#1532)
sdaulton Feb 15, 2023
8b49e5c
Add isinstance_af
sdaulton Feb 15, 2023
7ce1389
probabilistic reparameterization (#1533)
sdaulton Feb 15, 2023
fc9541a
Merge branch 'main' into prob-reparam
TobyBoyne Jun 6, 2025
f635524
Move is_nonnegative to optim.initializers again
TobyBoyne Jun 6, 2025
88c4230
Fix `FixedFeature` feature
TobyBoyne Jun 6, 2025
49b8ede
Fix `PenalizedAcquisition`
TobyBoyne Jun 6, 2025
a99ac62
Fix patching of isinstance_af in test_initializers
TobyBoyne Jun 6, 2025
c103033
Update types to PEP 604; fix flake8 line length errors
TobyBoyne Jun 6, 2025
7dcf186
Add test for PR with binary search space
TobyBoyne Jun 11, 2025
591b7ea
Add test for PR with categorical search space
TobyBoyne Jun 11, 2025
883fb89
Compare analytic vs MC PR in test
TobyBoyne Jun 12, 2025
526b873
Merge branch 'main' into prob-reparam
TobyBoyne Jun 13, 2025
f321b1f
Fix indexing bug when enumerating all discrete options
TobyBoyne Jun 18, 2025
ca69474
Test constructing PR input transforms
TobyBoyne Jun 18, 2025
ebfc9d7
Test forward pass of PR input transform
TobyBoyne Jun 18, 2025
6857ddf
Consolidate `*PRInputTransform`s
TobyBoyne Jun 18, 2025
fab81f7
Change order of integer idxs in PR test
TobyBoyne Jun 20, 2025
5a1b0fc
Create test for categorical PR
TobyBoyne Jun 22, 2025
0632e8e
Merge branch 'main' into prob-reparam
TobyBoyne Oct 20, 2025
64de246
Add error messages for invalid shapes/indices
TobyBoyne Oct 20, 2025
9d1f18e
Extract input shape checking in PR
TobyBoyne Oct 20, 2025
a8f536d
Test factory construction of PR Input Transform
TobyBoyne Oct 20, 2025
c2febd6
Minor typo/spelling consistency fix
TobyBoyne Oct 20, 2025
2d31180
Tests for invalid shapes in forward pass of input transform
TobyBoyne Oct 20, 2025
cfd43cb
Refactor PR code to be grouped
TobyBoyne Oct 20, 2025
47c9787
Test equality between input transforms
TobyBoyne Oct 20, 2025
6428679
Test purely continuous PR raises an error
TobyBoyne Oct 20, 2025
98de155
Test `get_probs` and `get_rounding_prob`
TobyBoyne Oct 21, 2025
9cdf3f3
Improve docstrings and comments
TobyBoyne Oct 21, 2025
37b6e2b
Add sample candidates test
TobyBoyne Oct 21, 2025
bbc5c91
Test purely continuous PR
TobyBoyne Oct 21, 2025
5077c69
Add tests for fully discrete; coverage of remaining options
TobyBoyne Oct 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions botorch/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,17 @@
qExpectedUtilityOfBestOption,
)
from botorch.acquisition.prior_guided import PriorGuidedAcquisitionFunction
from botorch.acquisition.probabilistic_reparameterization import (
AnalyticProbabilisticReparameterization,
MCProbabilisticReparameterization,
)
from botorch.acquisition.proximal import ProximalAcquisitionFunction

__all__ = [
"AcquisitionFunction",
"AnalyticAcquisitionFunction",
"AnalyticExpectedUtilityOfBestOption",
"AnalyticProbabilisticReparameterization",
"ConstrainedExpectedImprovement",
"DecoupledAcquisitionFunction",
"ExpectedImprovement",
Expand All @@ -94,6 +99,7 @@
"FixedFeatureAcquisitionFunction",
"GenericCostAwareUtility",
"InverseCostWeightedUtility",
"MCProbabilisticReparameterization",
"NoisyExpectedImprovement",
"OneShotAcquisitionFunction",
"PairwiseBayesianActiveLearningByDisagreement",
Expand Down
30 changes: 10 additions & 20 deletions botorch/acquisition/fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor
from torch.nn import Module


def get_dtype_of_sequence(values: Sequence[Tensor | float]) -> torch.dtype:
Expand Down Expand Up @@ -50,8 +50,8 @@ def _is_cuda(value: Tensor | float) -> bool:
return torch.device("cuda") if any_cuda else torch.device("cpu")


class FixedFeatureAcquisitionFunction(AcquisitionFunction):
"""A wrapper around AcquisitionFunctions to fix a subset of features.
class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AcquisitionFunctions to fix a subset of features.

Example:
>>> model = SingleTaskGP(train_X, train_Y) # d = 5
Expand Down Expand Up @@ -86,8 +86,9 @@ def __init__(
combination of `Tensor`s and numbers which can be broadcasted
to form a tensor with trailing dimension size of `d_f`.
"""
Module.__init__(self)
self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
dtype = torch.float
device = torch.device("cpu")
self.d = d

if isinstance(values, Tensor):
Expand Down Expand Up @@ -153,24 +154,13 @@ def forward(self, X: Tensor):
X_full = self._construct_X_full(X)
return self.acq_func(X_full)

@property
def X_pending(self):
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

@X_pending.setter
def X_pending(self, X_pending: Tensor | None):
def set_X_pending(self, X_pending: Tensor | None):
r"""Sets the `X_pending` of the base acquisition function."""
if X_pending is not None:
self.acq_func.X_pending = self._construct_X_full(X_pending)
full_X_pending = self._construct_X_full(X_pending)
else:
self.acq_func.X_pending = X_pending
full_X_pending = None
self.acq_func.set_X_pending(full_X_pending)

def _construct_X_full(self, X: Tensor) -> Tensor:
r"""Constructs the full input for the base acquisition function.
Expand Down
24 changes: 5 additions & 19 deletions botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.objective import GenericMCObjective
from botorch.exceptions import UnsupportedError
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor


Expand Down Expand Up @@ -201,7 +200,7 @@ def __call__(self, X: Tensor) -> Tensor:
return super().__call__(X=X).squeeze(dim=-1).min(dim=-1).values


class PenalizedAcquisitionFunction(AcquisitionFunction):
class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
r"""Single-outcome acquisition function regularized by the given penalty.

The usage is similar to:
Expand All @@ -223,29 +222,16 @@ def __init__(
penalty_func: The regularization function.
regularization_parameter: Regularization parameter used in optimization.
"""
super().__init__(model=raw_acqf.model)
self.raw_acqf = raw_acqf
AcquisitionFunction.__init__(self, model=raw_acqf.model)
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf)
self.penalty_func = penalty_func
self.regularization_parameter = regularization_parameter

def forward(self, X: Tensor) -> Tensor:
raw_value = self.raw_acqf(X=X)
raw_value = self.acq_func(X=X)
penalty_term = self.penalty_func(X)
return raw_value - self.regularization_parameter * penalty_term

@property
def X_pending(self) -> Tensor | None:
return self.raw_acqf.X_pending

def set_X_pending(self, X_pending: Tensor | None = None) -> None:
if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
self.raw_acqf.set_X_pending(X_pending=X_pending)
else:
raise UnsupportedError(
"The raw acquisition function is Analytic and does not account "
"for X_pending yet."
)


def group_lasso_regularizer(X: Tensor, groups: list[list[int]]) -> Tensor:
r"""Computes the group lasso regularization function for the given point.
Expand Down
Loading