-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- enable more priors - add basic prior iteration tests
- Loading branch information
Showing
9 changed files
with
266 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,19 @@ | ||
"""Available priors.""" | ||
|
||
from baybe.kernels.priors.basic import GammaPrior | ||
from baybe.kernels.priors.basic import ( | ||
GammaPrior, | ||
HalfCauchyPrior, | ||
HalfNormalPrior, | ||
LogNormalPrior, | ||
NormalPrior, | ||
SmoothedBoxPrior, | ||
) | ||
|
||
__all__ = ["GammaPrior"] | ||
__all__ = [ | ||
"GammaPrior", | ||
"HalfCauchyPrior", | ||
"HalfNormalPrior", | ||
"LogNormalPrior", | ||
"NormalPrior", | ||
"SmoothedBoxPrior", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,86 @@ | ||
"""Priors that can be used for kernels.""" | ||
from typing import Any | ||
|
||
from attrs import define, field | ||
from attrs.validators import gt | ||
|
||
from baybe.kernels.priors.base import Prior | ||
from baybe.utils.validation import finite_float | ||
|
||
|
||
@define(frozen=True) | ||
class GammaPrior(Prior): | ||
"""A Gamma prior parameterized by concentration and rate.""" | ||
|
||
concentration: float = field(converter=float, validator=gt(0.0)) | ||
concentration: float = field(converter=float, validator=[finite_float, gt(0.0)]) | ||
"""The concentration.""" | ||
|
||
rate: float = field(converter=float, validator=gt(0.0)) | ||
rate: float = field(converter=float, validator=[finite_float, gt(0.0)]) | ||
"""The rate.""" | ||
|
||
|
||
@define(frozen=True) | ||
class HalfCauchyPrior(Prior): | ||
"""A Half-Cauchy prior parameterized by a scale.""" | ||
|
||
scale: float = field(converter=float, validator=[finite_float, gt(0.0)]) | ||
"""The scale.""" | ||
|
||
|
||
@define(frozen=True) | ||
class NormalPrior(Prior): | ||
"""A Normal prior parameterized by location and scale.""" | ||
|
||
loc: float = field(converter=float, validator=finite_float) | ||
"""The location (mu).""" | ||
|
||
scale: float = field(converter=float, validator=[finite_float, gt(0.0)]) | ||
"""The scale (sigma).""" | ||
|
||
|
||
@define(frozen=True) | ||
class HalfNormalPrior(Prior): | ||
"""A Half-Normal prior parameterized by a scale.""" | ||
|
||
scale: float = field(converter=float, validator=[finite_float, gt(0.0)]) | ||
"""The scale (sigma).""" | ||
|
||
|
||
@define(frozen=True) | ||
class LogNormalPrior(Prior): | ||
"""A Log-Normal prior parameterized by location and scale.""" | ||
|
||
loc: float = field(converter=float, validator=finite_float) | ||
"""The location (mu).""" | ||
|
||
scale: float = field(converter=float, validator=[finite_float, gt(0.0)]) | ||
"""The scale (sigma).""" | ||
|
||
|
||
@define(frozen=True) | ||
class SmoothedBoxPrior(Prior): | ||
"""A Smoothed-Box prior parameterized by a, b and sigma.""" | ||
|
||
a: float = field(converter=float, validator=finite_float) | ||
"""The left/lower bound.""" | ||
|
||
b: float = field(converter=float, validator=finite_float) | ||
"""The right/upper bound.""" | ||
|
||
sigma: float = field( | ||
converter=float, default=0.01, validator=[finite_float, gt(0.0)] | ||
) | ||
"""The scale.""" | ||
|
||
@b.validator | ||
def _validate_order(self, _: Any, b: float) -> None: # noqa: DOC101, DOC103 | ||
"""Validate the order of both bounds. | ||
Raises: | ||
ValueError: If b is not larger than a. | ||
""" | ||
if b <= self.a: | ||
raise ValueError( | ||
f"For {self.__class__.__name__}, the upper bound `b` (provided: {b}) " | ||
f"must be larger than the lower bound `a` (provided: {self.a})." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
"""Validation utilities.""" | ||
|
||
import math | ||
from typing import Any, Callable | ||
|
||
from attrs import Attribute | ||
|
||
|
||
def _make_restricted_float_validator( | ||
allow_nan: bool, allow_inf: bool | ||
) -> Callable[[Any, Attribute, Any], None]: | ||
"""Make an attrs-compatible validator for restricted floats. | ||
Args: | ||
allow_nan: If False, validated values cannot be 'nan'. | ||
allow_inf: If False, validated values cannot be 'inf' or '-inf'. | ||
Raises: | ||
ValueError: If no float range restriction is in place. | ||
Returns: | ||
The validator. | ||
""" | ||
if allow_nan and allow_inf: | ||
raise ValueError( | ||
"The requested validator would not restrict the float range. " | ||
"Hence, you can use `attrs.validators.instance_of(float)` instead." | ||
) | ||
|
||
def validator(self: Any, attribute: Attribute, value: Any) -> None: | ||
if not isinstance(value, float): | ||
raise ValueError( | ||
f"Values assigned to attribute '{attribute.name}' of class " | ||
f"'{self.__class__.__name__}' must be of type 'float'. " | ||
f"Given: {value} (type: {type(value)})" | ||
) | ||
if not allow_inf and math.isinf(value): | ||
raise ValueError( | ||
f"Values assigned to attribute '{attribute.name}' of class " | ||
f"'{self.__class__.__name__}' cannot be 'inf' or '-inf'." | ||
) | ||
if not allow_nan and math.isnan(value): | ||
raise ValueError( | ||
f"Values assigned to attribute '{attribute.name}' of class " | ||
f"'{self.__class__.__name__}' cannot be 'nan'." | ||
) | ||
|
||
return validator | ||
|
||
|
||
finite_float = _make_restricted_float_validator(allow_nan=False, allow_inf=False) | ||
"""Validator for finite (i.e., non-nan and non-infinite) floats.""" | ||
|
||
non_nan_float = _make_restricted_float_validator(allow_nan=False, allow_inf=True) | ||
"""Validator for non-nan floats.""" | ||
|
||
non_inf_float = _make_restricted_float_validator(allow_nan=True, allow_inf=False) | ||
"""Validator for non-infinite floats.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Strategies for basic types.""" | ||
|
||
from functools import partial | ||
|
||
import hypothesis.strategies as st | ||
|
||
finite_floats = partial(st.floats, allow_infinity=False, allow_nan=False) | ||
"""A strategy producing finite (i.e., non-nan and non-infinite) floats.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters