Skip to content

Commit

Permalink
Add more log ACQFs (#308)
Browse files Browse the repository at this point in the history
* add mobo strategy

* update notebook

* update dependencies

* change inheritance hierarchy

* update specs

* update test

* fix it
  • Loading branch information
jduerholt authored Nov 9, 2023
1 parent 4cb522b commit d6576f6
Show file tree
Hide file tree
Showing 15 changed files with 694 additions and 71 deletions.
46 changes: 39 additions & 7 deletions bofire/data_models/acquisition_functions/acquisition_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,63 @@ class AcquisitionFunction(BaseModel):
type: str


class qNEI(AcquisitionFunction):
class SingleObjectiveAcquisitionFunction(AcquisitionFunction):
type: str


class MultiObjectiveAcquisitionFunction(AcquisitionFunction):
type: str


class qNEI(SingleObjectiveAcquisitionFunction):
type: Literal["qNEI"] = "qNEI"
prune_baseline: bool = True


class qLogNEI(AcquisitionFunction):
class qLogNEI(SingleObjectiveAcquisitionFunction):
type: Literal["qLogNEI"] = "qLogNEI"
prune_baseline: bool = True


class qEI(AcquisitionFunction):
class qEI(SingleObjectiveAcquisitionFunction):
type: Literal["qEI"] = "qEI"


class qLogEI(AcquisitionFunction):
class qLogEI(SingleObjectiveAcquisitionFunction):
type: Literal["qLogEI"] = "qLogEI"


class qSR(AcquisitionFunction):
class qSR(SingleObjectiveAcquisitionFunction):
type: Literal["qSR"] = "qSR"


class qUCB(AcquisitionFunction):
class qUCB(SingleObjectiveAcquisitionFunction):
type: Literal["qUCB"] = "qUCB"
beta: Annotated[float, Field(ge=0)] = 0.2


class qPI(AcquisitionFunction):
class qPI(SingleObjectiveAcquisitionFunction):
type: Literal["qPI"] = "qPI"
tau: PositiveFloat = 1e-3


class qEHVI(MultiObjectiveAcquisitionFunction):
type: Literal["qEHVI"] = "qEHVI"
alpha: Annotated[float, Field(ge=0)] = 0.0


class qLogEHVI(MultiObjectiveAcquisitionFunction):
type: Literal["qLogEHVI"] = "qLogEHVI"
alpha: Annotated[float, Field(ge=0)] = 0.0


class qNEHVI(MultiObjectiveAcquisitionFunction):
type: Literal["qNEHVI"] = "qNEHVI"
alpha: Annotated[float, Field(ge=0)] = 0.0
prune_baseline: bool = True


class qLogNEHVI(MultiObjectiveAcquisitionFunction):
type: Literal["qLogNEHVI"] = "qLogNEHVI"
alpha: Annotated[float, Field(ge=0)] = 0.0
prune_baseline: bool = True
22 changes: 20 additions & 2 deletions bofire/data_models/acquisition_functions/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,33 @@

from bofire.data_models.acquisition_functions.acquisition_function import (
AcquisitionFunction,
MultiObjectiveAcquisitionFunction,
SingleObjectiveAcquisitionFunction,
qEHVI,
qEI,
qLogEHVI,
qLogEI,
qLogNEHVI,
qLogNEI,
qNEHVI,
qNEI,
qPI,
qSR,
qUCB,
)

AbstractAcquisitionFunction = AcquisitionFunction
AbstractAcquisitionFunction = [
AcquisitionFunction,
SingleObjectiveAcquisitionFunction,
MultiObjectiveAcquisitionFunction,
]

AnyAcquisitionFunction = Union[
qNEI, qEI, qSR, qUCB, qPI, qLogEI, qLogNEI, qEHVI, qLogEHVI, qNEHVI, qLogNEHVI
]

AnySingleObjectiveAcquisitionFunction = Union[
qNEI, qEI, qSR, qUCB, qPI, qLogEI, qLogNEI
]

AnyAcquisitionFunction = Union[qNEI, qEI, qSR, qUCB, qPI, qLogEI, qLogNEI]
AnyMultiObjectiveAcquisitionFunction = Union[qEHVI, qLogEHVI, qNEHVI, qLogNEHVI]
3 changes: 3 additions & 0 deletions bofire/data_models/strategies/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from bofire.data_models.strategies.doe import DoEStrategy
from bofire.data_models.strategies.factorial import FactorialStrategy
from bofire.data_models.strategies.predictives.botorch import BotorchStrategy
from bofire.data_models.strategies.predictives.mobo import MoboStrategy
from bofire.data_models.strategies.predictives.multiobjective import (
MultiobjectiveStrategy,
)
Expand Down Expand Up @@ -53,6 +54,7 @@
DoEStrategy,
StepwiseStrategy,
FactorialStrategy,
MoboStrategy,
]

AnyPredictive = Union[
Expand All @@ -63,6 +65,7 @@
QehviStrategy,
QnehviStrategy,
QparegoStrategy,
MoboStrategy,
]

AnySampler = Union[PolytopeSampler, RejectionSampler]
Expand Down
76 changes: 76 additions & 0 deletions bofire/data_models/strategies/predictives/mobo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Dict, Literal, Optional, Type

from pydantic import Field, validator

from bofire.data_models.acquisition_functions.api import (
AnyMultiObjectiveAcquisitionFunction,
qLogNEHVI,
)
from bofire.data_models.features.api import CategoricalOutput, Feature
from bofire.data_models.objectives.api import (
CloseToTargetObjective,
MaximizeObjective,
MaximizeSigmoidObjective,
MinimizeObjective,
MinimizeSigmoidObjective,
Objective,
TargetObjective,
)
from bofire.data_models.strategies.predictives.multiobjective import (
MultiobjectiveStrategy,
)


class MoboStrategy(MultiobjectiveStrategy):
type: Literal["MoboStrategy"] = "MoboStrategy"
ref_point: Optional[Dict[str, float]] = None
acquisition_function: AnyMultiObjectiveAcquisitionFunction = Field(
default_factory=lambda: qLogNEHVI()
)

@validator("ref_point")
def validate_ref_point(cls, v, values):
"""Validate that the provided refpoint matches the provided domain."""
if v is None:
return v
keys = values["domain"].outputs.get_keys_by_objective(
[MaximizeObjective, MinimizeObjective]
)
if sorted(keys) != sorted(v.keys()):
raise ValueError(
f"Provided refpoint do not match the domain, expected keys: {keys}"
)
return v

@classmethod
def is_feature_implemented(cls, my_type: Type[Feature]) -> bool:
"""Method to check if a specific feature type is implemented for the strategy
Args:
my_type (Type[Feature]): Feature class
Returns:
bool: True if the feature type is valid for the strategy chosen, False otherwise
"""
if my_type not in [CategoricalOutput]:
return True
return False

@classmethod
def is_objective_implemented(cls, my_type: Type[Objective]) -> bool:
"""Method to check if a objective type is implemented for the strategy
Args:
my_type (Type[Objective]): Objective class
Returns:
bool: True if the objective type is valid for the strategy chosen, False otherwise
"""
return my_type in [
MaximizeObjective,
MinimizeObjective,
MinimizeSigmoidObjective,
MaximizeSigmoidObjective,
TargetObjective,
CloseToTargetObjective,
]
9 changes: 7 additions & 2 deletions bofire/data_models/strategies/predictives/sobo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@

from pydantic import Field, validator

from bofire.data_models.acquisition_functions.api import AnyAcquisitionFunction, qNEI
from bofire.data_models.acquisition_functions.api import (
AnySingleObjectiveAcquisitionFunction,
qLogNEI,
)
from bofire.data_models.features.api import CategoricalOutput, Feature
from bofire.data_models.objectives.api import ConstrainedObjective, Objective
from bofire.data_models.strategies.predictives.botorch import BotorchStrategy


class SoboBaseStrategy(BotorchStrategy):
acquisition_function: AnyAcquisitionFunction = Field(default_factory=lambda: qNEI())
acquisition_function: AnySingleObjectiveAcquisitionFunction = Field(
default_factory=lambda: qLogNEI()
)

@classmethod
def is_feature_implemented(cls, my_type: Type[Feature]) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions bofire/strategies/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from bofire.strategies.doe_strategy import DoEStrategy # noqa: F401
from bofire.strategies.factorial import FactorialStrategy
from bofire.strategies.predictives.botorch import BotorchStrategy # noqa: F401
from bofire.strategies.predictives.mobo import MoboStrategy
from bofire.strategies.predictives.predictive import PredictiveStrategy # noqa: F401
from bofire.strategies.predictives.qehvi import QehviStrategy # noqa: F401
from bofire.strategies.predictives.qnehvi import QnehviStrategy # noqa: F401
Expand Down Expand Up @@ -35,6 +36,7 @@
data_models.DoEStrategy: DoEStrategy,
data_models.StepwiseStrategy: StepwiseStrategy,
data_models.FactorialStrategy: FactorialStrategy,
data_models.MoboStrategy: MoboStrategy,
}


Expand Down
Loading

0 comments on commit d6576f6

Please sign in to comment.