Merge: More Priors (#225)
- enable more priors
- add basic prior iteration tests
Scienfitz authored May 6, 2024
2 parents ceb4a32 + 53e05c8 commit cd733a4
Showing 9 changed files with 266 additions and 17 deletions.
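
In practice, this change lets any of the newly exposed priors be passed as a lengthscale prior when constructing a kernel for the Gaussian process surrogate, mirroring the fixture pattern added to tests/conftest.py below. A minimal sketch (illustrative only, not part of the commit; the GaussianProcessSurrogate import path is assumed):

from baybe.kernels import MaternKernel
from baybe.kernels.priors import NormalPrior
from baybe.surrogates import GaussianProcessSurrogate  # import path assumed

# Any of the newly enabled priors can serve as the lengthscale prior.
kernel = MaternKernel(nu=5 / 2, lengthscale_prior=NormalPrior(1, 2))
surrogate = GaussianProcessSurrogate(kernel=kernel)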
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -14,9 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `hypothesis` strategies and roundtrip test for kernels, constraints, objectives, priors
and acquisition functions
- New acquisition functions: `qSR`, `qNEI`, `LogEI`, `qLogEI`, `qLogNEI`
- `GammaPrior` can now be chosen as lengthscale prior
- Serialization user guide
- Basic deserialization tests using different class type specifiers
- `GammaPrior`, `HalfCauchyPrior`, `NormalPrior`, `HalfNormalPrior`, `LogNormalPrior`
and `SmoothedBoxPrior` can now be chosen as lengthscale prior

### Changed
- Reorganized acquisition.py into `acquisition` subpackage
18 changes: 16 additions & 2 deletions baybe/kernels/priors/__init__.py
@@ -1,5 +1,19 @@
"""Available priors."""

from baybe.kernels.priors.basic import GammaPrior
from baybe.kernels.priors.basic import (
GammaPrior,
HalfCauchyPrior,
HalfNormalPrior,
LogNormalPrior,
NormalPrior,
SmoothedBoxPrior,
)

__all__ = ["GammaPrior"]
__all__ = [
"GammaPrior",
"HalfCauchyPrior",
"HalfNormalPrior",
"LogNormalPrior",
"NormalPrior",
"SmoothedBoxPrior",
]
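
With the expanded __all__, every new prior is importable directly from the subpackage, e.g. (illustrative only, not part of the commit):

from baybe.kernels.priors import HalfCauchyPrior, SmoothedBoxPrior

half_cauchy = HalfCauchyPrior(2)            # scale > 0
smoothed_box = SmoothedBoxPrior(0, 3, 0.1)  # a < b, sigma > 0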
74 changes: 72 additions & 2 deletions baybe/kernels/priors/basic.py
@@ -1,16 +1,86 @@
"""Priors that can be used for kernels."""
from typing import Any

from attrs import define, field
from attrs.validators import gt

from baybe.kernels.priors.base import Prior
from baybe.utils.validation import finite_float


@define(frozen=True)
class GammaPrior(Prior):
"""A Gamma prior parameterized by concentration and rate."""

concentration: float = field(converter=float, validator=gt(0.0))
concentration: float = field(converter=float, validator=[finite_float, gt(0.0)])
"""The concentration."""

rate: float = field(converter=float, validator=gt(0.0))
rate: float = field(converter=float, validator=[finite_float, gt(0.0)])
"""The rate."""


@define(frozen=True)
class HalfCauchyPrior(Prior):
"""A Half-Cauchy prior parameterized by a scale."""

scale: float = field(converter=float, validator=[finite_float, gt(0.0)])
"""The scale."""


@define(frozen=True)
class NormalPrior(Prior):
"""A Normal prior parameterized by location and scale."""

loc: float = field(converter=float, validator=finite_float)
"""The location (mu)."""

scale: float = field(converter=float, validator=[finite_float, gt(0.0)])
"""The scale (sigma)."""


@define(frozen=True)
class HalfNormalPrior(Prior):
"""A Half-Normal prior parameterized by a scale."""

scale: float = field(converter=float, validator=[finite_float, gt(0.0)])
"""The scale (sigma)."""


@define(frozen=True)
class LogNormalPrior(Prior):
"""A Log-Normal prior parameterized by location and scale."""

loc: float = field(converter=float, validator=finite_float)
"""The location (mu)."""

scale: float = field(converter=float, validator=[finite_float, gt(0.0)])
"""The scale (sigma)."""


@define(frozen=True)
class SmoothedBoxPrior(Prior):
"""A Smoothed-Box prior parameterized by a, b and sigma."""

a: float = field(converter=float, validator=finite_float)
"""The left/lower bound."""

b: float = field(converter=float, validator=finite_float)
"""The right/upper bound."""

sigma: float = field(
converter=float, default=0.01, validator=[finite_float, gt(0.0)]
)
"""The scale."""

@b.validator
def _validate_order(self, _: Any, b: float) -> None: # noqa: DOC101, DOC103
"""Validate the order of both bounds.
Raises:
ValueError: If b is not larger than a.
"""
if b <= self.a:
raise ValueError(
f"For {self.__class__.__name__}, the upper bound `b` (provided: {b}) "
f"must be larger than the lower bound `a` (provided: {self.a})."
)
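
For illustration (not part of the commit), a short sketch of how the validators defined above behave at construction time:

from baybe.kernels.priors import NormalPrior, SmoothedBoxPrior

SmoothedBoxPrior(0, 3, 0.1)  # valid: a < b and sigma > 0
NormalPrior(1, 2)            # valid: finite loc, positive scale

try:
    SmoothedBoxPrior(3, 0)  # rejected: upper bound `b` must exceed lower bound `a`
except ValueError as err:
    print(err)

try:
    NormalPrior(float("nan"), 1.0)  # rejected by the finite_float validator
except ValueError as err:
    print(err)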
58 changes: 58 additions & 0 deletions baybe/utils/validation.py
@@ -0,0 +1,58 @@
"""Validation utilities."""

import math
from typing import Any, Callable

from attrs import Attribute


def _make_restricted_float_validator(
allow_nan: bool, allow_inf: bool
) -> Callable[[Any, Attribute, Any], None]:
"""Make an attrs-compatible validator for restricted floats.
Args:
allow_nan: If False, validated values cannot be 'nan'.
allow_inf: If False, validated values cannot be 'inf' or '-inf'.
Raises:
ValueError: If no float range restriction is in place.
Returns:
The validator.
"""
if allow_nan and allow_inf:
raise ValueError(
"The requested validator would not restrict the float range. "
"Hence, you can use `attrs.validators.instance_of(float)` instead."
)

def validator(self: Any, attribute: Attribute, value: Any) -> None:
if not isinstance(value, float):
raise ValueError(
f"Values assigned to attribute '{attribute.name}' of class "
f"'{self.__class__.__name__}' must be of type 'float'. "
f"Given: {value} (type: {type(value)})"
)
if not allow_inf and math.isinf(value):
raise ValueError(
f"Values assigned to attribute '{attribute.name}' of class "
f"'{self.__class__.__name__}' cannot be 'inf' or '-inf'."
)
if not allow_nan and math.isnan(value):
raise ValueError(
f"Values assigned to attribute '{attribute.name}' of class "
f"'{self.__class__.__name__}' cannot be 'nan'."
)

return validator


finite_float = _make_restricted_float_validator(allow_nan=False, allow_inf=False)
"""Validator for finite (i.e., non-nan and non-infinite) floats."""

non_nan_float = _make_restricted_float_validator(allow_nan=False, allow_inf=True)
"""Validator for non-nan floats."""

non_inf_float = _make_restricted_float_validator(allow_nan=True, allow_inf=False)
"""Validator for non-infinite floats."""
18 changes: 16 additions & 2 deletions tests/conftest.py
@@ -27,6 +27,8 @@
ThresholdCondition,
)
from baybe.exceptions import OptionalImportError
from baybe.kernels import MaternKernel
from baybe.kernels.priors import GammaPrior
from baybe.objectives.desirability import DesirabilityObjective
from baybe.objectives.single import SingleTargetObjective
from baybe.parameters import (
@@ -600,12 +602,24 @@ def fixture_default_acquisition_function():
return qExpectedImprovement()


@pytest.fixture(name="lengthscale_prior")
def fixture_default_lengthscale_prior():
"""The default lengthscale prior to be used if not specified differently."""
return GammaPrior(3, 1)


@pytest.fixture(name="kernel")
def fixture_default_kernel(lengthscale_prior):
"""The default kernel to be used if not specified differently."""
return MaternKernel(nu=5 / 2, lengthscale_prior=lengthscale_prior)


@pytest.fixture(name="surrogate_model")
def fixture_default_surrogate_model(request, onnx_surrogate):
def fixture_default_surrogate_model(request, onnx_surrogate, kernel):
"""The default surrogate model to be used if not specified differently."""
if hasattr(request, "param") and request.param == "onnx":
return onnx_surrogate
return GaussianProcessSurrogate()
return GaussianProcessSurrogate(kernel=kernel)


@pytest.fixture(name="initial_recommender")
8 changes: 8 additions & 0 deletions tests/hypothesis_strategies/basic.py
@@ -0,0 +1,8 @@
"""Strategies for basic types."""

from functools import partial

import hypothesis.strategies as st

finite_floats = partial(st.floats, allow_infinity=False, allow_nan=False)
"""A strategy producing finite (i.e., non-nan and non-infinite) floats."""
61 changes: 58 additions & 3 deletions tests/hypothesis_strategies/priors.py
@@ -2,18 +2,73 @@

import hypothesis.strategies as st

from baybe.kernels.priors import GammaPrior
from baybe.kernels.priors import (
GammaPrior,
HalfCauchyPrior,
HalfNormalPrior,
LogNormalPrior,
NormalPrior,
SmoothedBoxPrior,
)

from .basic import finite_floats
from .utils import intervals

gamma_priors = st.builds(
GammaPrior,
st.floats(min_value=0, exclude_min=True),
st.floats(min_value=0, exclude_min=True),
finite_floats(min_value=0.0, exclude_min=True),
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Gamma priors."""

half_cauchy_priors = st.builds(
HalfCauchyPrior,
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Half-Cauchy priors."""

normal_priors = st.builds(
NormalPrior,
finite_floats(),
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Normal priors."""

half_normal_priors = st.builds(
HalfNormalPrior,
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Half-Normal priors."""

log_normal_priors = st.builds(
LogNormalPrior,
finite_floats(),
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Log-Normal priors."""


@st.composite
def _smoothed_box_priors(draw: st.DrawFn):
"""A strategy that generates Smoothed-Box priors."""
interval = draw(intervals(exclude_half_bounded=True, exclude_fully_unbounded=True))
sigma = draw(
finite_floats(min_value=0.0, exclude_min=True),
)

return SmoothedBoxPrior(*interval.to_tuple(), sigma)


smoothed_box_priors = _smoothed_box_priors()

priors = st.one_of(
[
gamma_priors,
half_cauchy_priors,
half_normal_priors,
log_normal_priors,
normal_priors,
smoothed_box_priors,
]
)
"""A strategy that generates priors."""
11 changes: 5 additions & 6 deletions tests/hypothesis_strategies/utils.py
@@ -7,6 +7,8 @@

from baybe.utils.interval import Interval

from .basic import finite_floats


class IntervalType(Enum):
"""The possible types of an interval on the real number line."""
@@ -38,18 +40,15 @@ def intervals(
allowed_types = [t for t, b in type_gate.items() if b]
interval_type = draw(st.sampled_from(allowed_types))

# A strategy producing finite floats
ffloats = st.floats(allow_infinity=False, allow_nan=False)

# Draw the bounds depending on the interval type
if interval_type is IntervalType.FULLY_UNBOUNDED:
bounds = (None, None)
elif interval_type is IntervalType.HALF_BOUNDED:
bounds = draw(
st.sampled_from(
[
(None, draw(ffloats)),
(draw(ffloats), None),
(None, draw(finite_floats())),
(draw(finite_floats()), None),
]
)
)
@@ -58,7 +57,7 @@
hnp.arrays(
dtype=float,
shape=(2,),
elements=ffloats,
elements=finite_floats(),
unique=True,
).map(sorted)
)
32 changes: 31 additions & 1 deletion tests/test_iterations.py
@@ -5,6 +5,14 @@
import pytest

from baybe.acquisition.base import AcquisitionFunction
from baybe.kernels.priors import (
GammaPrior,
HalfCauchyPrior,
HalfNormalPrior,
LogNormalPrior,
NormalPrior,
SmoothedBoxPrior,
)
from baybe.recommenders.meta.base import MetaRecommender
from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender
from baybe.recommenders.naive import NaiveHybridSpaceRecommender
@@ -113,6 +121,15 @@

valid_meta_recommenders = get_subclasses(MetaRecommender)

valid_priors = [
GammaPrior(3, 1),
HalfCauchyPrior(2),
HalfNormalPrior(2),
LogNormalPrior(1, 2),
NormalPrior(1, 2),
SmoothedBoxPrior(0, 3, 0.1),
]

test_targets = [
["Target_max"],
["Target_min"],
@@ -146,7 +163,20 @@ def test_iter_nonmc_acquisition_function(campaign, n_iterations, batch_size):


@pytest.mark.slow
@pytest.mark.parametrize("surrogate_model", valid_surrogate_models)
@pytest.mark.parametrize(
"lengthscale_prior", valid_priors, ids=[c.__class__ for c in valid_priors]
)
@pytest.mark.parametrize("n_iterations", [3], ids=["i3"])
def test_iter_prior(campaign, n_iterations, batch_size):
run_iterations(campaign, n_iterations, batch_size)


@pytest.mark.slow
@pytest.mark.parametrize(
"surrogate_model",
valid_surrogate_models,
ids=[c.__class__ for c in valid_surrogate_models],
)
def test_iter_surrogate_model(campaign, n_iterations, batch_size):
run_iterations(campaign, n_iterations, batch_size)

