1111
1212from __future__ import annotations
1313
14- import warnings
1514from abc import ABC , abstractmethod
1615from collections .abc import Callable
1716
2120 IdentityMCObjective ,
2221 MCAcquisitionObjective ,
2322)
24- from botorch .exceptions .warnings import CostAwareWarning
2523from botorch .models .deterministic import DeterministicModel
2624from botorch .models .gpytorch import GPyTorchModel
2725from botorch .sampling .base import MCSampler
@@ -112,7 +110,7 @@ def __init__(
112110 cost_model : DeterministicModel | GPyTorchModel ,
113111 use_mean : bool = True ,
114112 cost_objective : MCAcquisitionObjective | None = None ,
115- min_cost : float = 1e-2 ,
113+ log : bool = False ,
116114 ) -> None :
117115 r"""Cost-aware utility that weights increase in utility by inverse cost.
118116 For negative increases in utility, the utility is instead scaled by the
@@ -130,7 +128,8 @@ def __init__(
130128 un-transform predictions/samples of a cost model fit on the
131129 log-transformed cost (often done to ensure non-negativity). If the
132130 cost model is multi-output, then by default this will sum the cost
133- across outputs.
131+ across outputs. NOTE: ``cost_objective`` must output
132+ strictly positive values; forward will raise a ``ValueError`` otherwise.
134133 log: If True, the input deltas are assumed to be in log-space
135134 and the returned utility is ``deltas - log(cost)``.
136135 Returns:
@@ -147,7 +146,7 @@ def __init__(
147146 self .cost_model = cost_model
148147 self .cost_objective : MCAcquisitionObjective = cost_objective
149148 self ._use_mean = use_mean
150- self ._min_cost = min_cost
149+ self ._log = log
151150
152151 def forward (
153152 self ,
@@ -202,18 +201,21 @@ def forward(
202201 cost = none_throws (sampler )(cost_posterior )
203202 cost = self .cost_objective (cost )
204203
205- # Ensure non-negativity of the cost
206- if torch .any (cost < - 1e-7 ):
207- warnings .warn (
208- "Encountered negative cost values in InverseCostWeightedUtility" ,
209- CostAwareWarning ,
210- stacklevel = 2 ,
204+ # Ensure that costs are positive
205+ if not torch .all (cost > 0.0 ):
206+ raise ValueError (
207+ "Costs must be strictly positive. Consider clamping cost_objective."
211208 )
212- # clamp (away from zero) and sum cost across elements of the q-batch -
213- # this will be of shape `num_fantasies x batch_shape` or `batch_shape`
214- cost = cost .clamp_min ( self . _min_cost ). sum (dim = - 1 )
209+
210+ # sum costs along q-batch
211+ cost = cost .sum (dim = - 1 )
215212
216213 # compute and return the ratio on the sample level - If `use_mean=True`
217214 # this operation involves broadcasting the cost across fantasies.
218- # We multiply by the cost if the deltas are <= 0, see discussion #2914
219- return torch .where (deltas > 0 , deltas / cost , deltas * cost )
215+ if self ._log :
216+ # if _log is True then input deltas are in log space
217+ # so original deltas cannot be <= 0
218+ return deltas - torch .log (cost )
219+ else :
220+ # We multiply by the cost if the deltas are <= 0, see discussion #2914
221+ return torch .where (deltas > 0 , deltas / cost , deltas * cost )
0 commit comments