Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Desirability functions as objectives #497

Merged
merged 31 commits into from
Jan 21, 2025
Merged
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
74797a5
initial commit of numerical objectives. Needs to be adjusted, tested …
LukasHebing Jan 9, 2025
9c13ac7
moved torch functions to torch_tools.py
LukasHebing Jan 10, 2025
0cabd72
removed torch dependencies from data-model
LukasHebing Jan 10, 2025
92d7a93
added validators and tests for desirability data-models
LukasHebing Jan 10, 2025
74a9757
after hooks
LukasHebing Jan 10, 2025
cce5ee1
added test for "get_objective_callable"
LukasHebing Jan 10, 2025
9ce1a98
after hooks
LukasHebing Jan 10, 2025
364e3d0
added tutorial notebook desirability_objectives.ipynb
LukasHebing Jan 10, 2025
ca00d4f
after hooks
LukasHebing Jan 10, 2025
945cccf
added to AnyRealObjective
LukasHebing Jan 13, 2025
da33ff7
after hooks
LukasHebing Jan 13, 2025
20e277a
Merge remote-tracking branch 'origin/main' into feature/desirability_…
LukasHebing Jan 14, 2025
d99a308
Merge remote-tracking branch 'origin/main' into feature/desirability_…
LukasHebing Jan 15, 2025
d448d0f
changed validators to model validators
LukasHebing Jan 16, 2025
07685d1
added type: Literals to objectives
LukasHebing Jan 16, 2025
f97bd8e
after hooks
LukasHebing Jan 16, 2025
370fbed
after hooks
LukasHebing Jan 16, 2025
74a222e
debugged new validators
LukasHebing Jan 16, 2025
6c66331
after hooks
LukasHebing Jan 16, 2025
a45eba3
fixed tests
LukasHebing Jan 16, 2025
64006ef
after hooks
LukasHebing Jan 16, 2025
a86de30
got rid of using desirability base class as actual usable class
LukasHebing Jan 16, 2025
2d7d05e
- moved clip to abstract class
LukasHebing Jan 20, 2025
a4faa89
added abstractmehtod decorator
LukasHebing Jan 20, 2025
a93e928
after hooks
LukasHebing Jan 20, 2025
d939aec
changed data model specs for tests
LukasHebing Jan 20, 2025
d1e63ea
changed bounds defs in specs to lists
LukasHebing Jan 20, 2025
6c8bc67
debugged invalid specs definition
LukasHebing Jan 20, 2025
d15640f
after hooks
LukasHebing Jan 20, 2025
1934c39
moved helper class with __call__ method to abstract desirability class
LukasHebing Jan 20, 2025
cf4b7e0
after hooks
LukasHebing Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixed tests
LukasHebing committed Jan 16, 2025
commit a45eba301b37cb9c962209cbbaefd3efe6f9484b
46 changes: 27 additions & 19 deletions bofire/data_models/objectives/desirabilities.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,27 @@ class DesirabilityObjective(IdentityObjective):

type: Literal["DesirabilityObjective"] = "DesirabilityObjective" # type: ignore

@pydantic.model_validator(mode="after")
def validate_clip(self):

if not "clip" in list(self.__dict__):
return self

if self.clip:
return self

log_shapes = {
key: val
for (key, val) in self.__dict__.items()
if key.startswith("log_shape_factor")
}
for key, log_shape_ in log_shapes.items():
if log_shape_ != 0:
raise ValueError(
f"Log shape factor {key} must be zero if clip is False."
)
return self


class IncreasingDesirabilityObjective(_SeriesNumpyCallable, DesirabilityObjective):
"""An objective returning a reward the scaled identity, but trimmed at the bounds:
@@ -88,25 +109,8 @@ def call_numpy(

return y

@pydantic.model_validator(mode="after")
def validate_clip(self):
if self.clip:
return self

log_shapes = {
key: val
for (key, val) in self.__dict__.items()
if key.startswith("log_shape_factor")
}
for key, log_shape_ in log_shapes.items():
if log_shape_ != 0:
raise ValueError(
f"Log shape factor {key} must be zero if clip is False."
)
return self


class DecreasingDesirabilityObjective(IncreasingDesirabilityObjective):
class DecreasingDesirabilityObjective(_SeriesNumpyCallable, DesirabilityObjective):
"""An objective returning a reward the negative, shifted scaled identity, but trimmed at the bounds:

d = ((upper_bound - x) / (upper_bound - lower_bound))^t
@@ -131,6 +135,8 @@ class DecreasingDesirabilityObjective(IncreasingDesirabilityObjective):
"""

type: Literal["DecreasingDesirabilityObjective"] = "DecreasingDesirabilityObjective" # type: ignore
log_shape_factor: float = 0.0
clip: bool = True

def call_numpy(
self,
@@ -154,7 +160,7 @@ def call_numpy(
return y


class PeakDesirabilityObjective(IncreasingDesirabilityObjective):
class PeakDesirabilityObjective(_SeriesNumpyCallable, DesirabilityObjective):
"""
A piecewise (linear or convex/concave) objective that increases from the lower bound
to the peak position and decreases from the peak position to the upper bound.
@@ -177,6 +183,8 @@ class PeakDesirabilityObjective(IncreasingDesirabilityObjective):
"""

type: Literal["PeakDesirabilityObjective"] = "PeakDesirabilityObjective" # type: ignore
log_shape_factor: float = 0.0
clip: bool = True
log_shape_factor_decreasing: float = 0.0 # often named log_t
peak_position: float = 0.5 # often named T

3 changes: 1 addition & 2 deletions pyright_output.txt
Original file line number Diff line number Diff line change
@@ -753,7 +753,6 @@
  Type "Series | ndarray[Unknown, Unknown]" is not assignable to type "ndarray[Unknown, Unknown]"
    "Series" is not assignable to "ndarray[Unknown, Unknown]" (reportArgumentType)
/Users/gdiwt/SourceCode/bofire_global/bofire/data_models/objectives/desirabilities.py:26:38 - error: "name" is possibly unbound (reportPossiblyUnboundVariable)
/Users/gdiwt/SourceCode/bofire_global/bofire/data_models/objectives/desirabilities.py:216:83 - error: "v" is not defined (reportUndefinedVariable)
/Users/gdiwt/SourceCode/bofire_global/bofire/data_models/objectives/identity.py
/Users/gdiwt/SourceCode/bofire_global/bofire/data_models/objectives/identity.py:82:5 - error: "type" overrides symbol of same name in class "IdentityObjective"
  Variable is mutable so its type is invariant
@@ -5023,4 +5022,4 @@
  "tuple[Literal[0], Literal[1]]" is not assignable to "List[float]" (reportArgumentType)
/Users/gdiwt/SourceCode/bofire_global/tests/bofire/utils/test_torch_tools.py:1031:38 - error: Expected 3 more positional arguments (reportCallIssue)
/Users/gdiwt/SourceCode/bofire_global/tests/bofire/utils/test_torch_tools.py:1040:38 - error: Expected 3 more positional arguments (reportCallIssue)
1773 errors, 2 warnings, 0 informations
1772 errors, 2 warnings, 0 informations
1 change: 0 additions & 1 deletion tests/bofire/utils/test_torch_tools.py
Original file line number Diff line number Diff line change
@@ -113,7 +113,6 @@
CloseToTargetObjective(target_value=2.0, exponent=1.0, w=0.5),
MovingMaximizeSigmoidObjective(steepness=1, tp=-1, w=1),
# ConstantObjective(w=0.5, value=1.0),
DesirabilityObjective(),
IncreasingDesirabilityObjective(
bounds=(0, 2.5), log_shape_factor=0.0, clip=False
),