Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hypothesis for targets and intervals #56

Merged
merged 16 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Target enums
- `mypy` for targets and intervals
- Tests for code blocks in README and user guides
- `hypothesis` strategies and tests for targets and intervals
- De-/serialization of target subclasses via base class

### Changed
- Renamed `bounds_transform_func` target attribute to `transformation`
Expand All @@ -20,12 +22,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Wrong use of `tolerance` argument in constraints user guide
- Errors with generics and type aliases in documentation
- Deduplication bug in substance_data hypothesis strategy

### Removed
- Conda install instructions and version badge

### Deprecations
- `Interval.is_finite` replaced with `Interval.is_bounded`
- Specifying target configs without explicit type information is deprecated

## [0.7.1] - 2023-12-07
### Added
Expand Down
3 changes: 2 additions & 1 deletion baybe/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from attr import define, field
from attr.validators import deep_iterable, in_, instance_of, min_len

from baybe.targets.base import Target
from baybe.targets.numerical import NumericalTarget
from baybe.utils import SerialMixin, geom_mean

Expand Down Expand Up @@ -37,7 +38,7 @@ class Objective(SerialMixin):
mode: Literal["SINGLE", "DESIRABILITY"] = field()
"""The optimization mode."""

targets: List[NumericalTarget] = field(validator=min_len(1))
targets: List[Target] = field(validator=min_len(1))
"""The list of targets used for the objective."""

weights: List[float] = field(converter=_normalize_weights)
Expand Down
40 changes: 38 additions & 2 deletions baybe/targets/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
"""Base functionality for all BayBE targets."""

import warnings
from abc import ABC, abstractmethod

import pandas as pd
from attrs import define, field

from baybe.utils import (
SerialMixin,
converter,
get_base_structure_hook,
unstructure_base,
)


@define(frozen=True)
class Target(ABC):
class Target(ABC, SerialMixin):
"""Abstract base class for all target variables.

Stores information about the range, transformations, etc.
Expand All @@ -29,3 +36,32 @@ def transform(self, data: pd.DataFrame) -> pd.DataFrame:
Returns:
A dataframe containing the transformed data.
"""


def _add_missing_type_hook(hook):
"""Adjust the structuring hook such that it auto-fills missing target types.
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved

Used for backward compatibility only and will be removed in future versions.
"""

def added_type_hook(dict_, cls):
if "type" not in dict_:
warnings.warn(
f"The target type is not specified for target '{dict_['name']}' and "
f"thus automatically set to 'NumericalTarget'. "
f"However, omitting the target type is deprecated and will no longer "
f"be supported in future versions. "
f"Therefore, please add an explicit target type.",
DeprecationWarning,
)
dict_["type"] = "NumericalTarget"
return hook(dict_, cls)

return added_type_hook


# Register (un-)structure hooks
converter.register_structure_hook(
Target, _add_missing_type_hook(get_base_structure_hook(Target))
)
converter.register_unstructure_hook(Target, unstructure_base)
5 changes: 5 additions & 0 deletions baybe/utils/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def is_half_bounded(self) -> bool:
"""Check if the interval is half-bounded."""
return self.is_left_bounded ^ self.is_right_bounded

@property
def is_fully_unbounded(self) -> bool:
"""Check if the interval represents the entire real number line."""
return not (self.is_left_bounded or self.is_right_bounded)

@property
def is_finite(self) -> bool:
"""Check whether the interval is finite."""
Expand Down
1 change: 1 addition & 0 deletions tests/hypothesis_strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Hypothesis strategies."""
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Test alternative ways of creation not considered in the strategies."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Test alternative ways of creation not considered in the strategies."""

import pytest

from baybe.parameters.categorical import CategoricalParameter
from baybe.parameters.enum import CategoricalEncoding, SubstanceEncoding
from baybe.parameters.substance import SubstanceParameter


@pytest.mark.parametrize("encoding", [e.name for e in CategoricalEncoding])
def test_string_encoding_categorical_parameter(encoding):
"""The encoding can also be specified as a string instead of an enum value."""
CategoricalParameter(name="string_encoding", values=["A", "B"], encoding=encoding)


@pytest.mark.parametrize("encoding", [e.name for e in SubstanceEncoding])
def test_string_encoding_substance_parameter(encoding):
"""The encoding can also be specified as a string instead of an enum value."""
SubstanceParameter(
name="string_encoding", data={"A": "C", "B": "CC"}, encoding=encoding
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Test alternative ways of creation not considered in the strategies."""

import pytest
from pytest import param

from baybe.targets.enum import TargetMode, TargetTransformation
from baybe.targets.numerical import NumericalTarget


@pytest.mark.parametrize(
"bounds",
[
param((None, None), id="unbounded"),
param((0, 1), id="bounded"),
],
)
def test_unspecified_transformation(bounds):
"""The transformation can be unspecified, in which case a default is chosen.

This explicitly tests the logic of the attrs default method.
"""
NumericalTarget("unspecified", mode="MAX", bounds=bounds)


@pytest.mark.parametrize("mode", (m.name for m in TargetMode))
def test_string_mode(mode):
"""The mode can also be specified as a string instead of an enum value."""
NumericalTarget("string_mode", mode=mode, bounds=(0, 1))


@pytest.mark.parametrize("transformation", (t.name for t in TargetTransformation))
def test_string_transformation(transformation):
"""The transformation can also be specified as a string instead of an enum value."""
mode = "MAX" if transformation == "LINEAR" else "MATCH"
NumericalTarget(
"string_mode", mode=mode, bounds=(0, 1), transformation=transformation
)
19 changes: 19 additions & 0 deletions tests/hypothesis_strategies/alternative_creation/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Test alternative ways of creation not considered in the strategies."""

import pytest
from pytest import param

from baybe.utils.interval import Interval


@pytest.mark.parametrize(
("lower", "upper"),
[
param(None, 0, id="left-unbounded"),
param(0, None, id="right-unbounded"),
param(None, None, id="fully-unbounded"),
],
)
def test_none_bounds(lower, upper):
"""Bounds can also be None."""
Interval(lower, upper)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Hypothesis strategies."""
"""Hypothesis strategies for parameters."""

import hypothesis.strategies as st
import numpy as np
Expand All @@ -18,13 +18,7 @@
from baybe.utils.chemistry import get_canonical_smiles
from baybe.utils.numeric import DTypeFloatNumpy

_largest_lower_interval = np.nextafter(
np.nextafter(np.inf, 0, dtype=DTypeFloatNumpy), 0, dtype=DTypeFloatNumpy
)
"""
The largest possible value for the lower end of a continuous interval such that there
still exists a larger but finite number for the upper interval end.
"""
from .utils import interval

decorrelation = st.one_of(
st.booleans(),
Expand Down Expand Up @@ -55,9 +49,13 @@ def substance_data(draw: st.DrawFn):
"""Generate data for :class:`baybe.parameters.substance.SubstanceParameter`."""
names = draw(st.lists(st.text(min_size=1), min_size=2, max_size=10, unique=True))
substances = draw(
st.lists(smiles(), min_size=len(names), max_size=len(names), unique=True)
st.lists(
smiles().map(get_canonical_smiles),
min_size=len(names),
max_size=len(names),
unique=True,
)
)
substances = list(set(get_canonical_smiles(s) for s in substances))
return dict(zip(names, substances))


Expand Down Expand Up @@ -109,9 +107,8 @@ def numerical_discrete_parameter(
def numerical_continuous_parameter(draw: st.DrawFn):
"""Generate :class:`baybe.parameters.numerical.NumericalContinuousParameter`."""
name = draw(parameter_name)
lower = draw(st.floats(max_value=_largest_lower_interval, allow_infinity=False))
upper = draw(st.floats(min_value=lower, exclude_min=True, allow_infinity=False))
return NumericalContinuousParameter(name=name, bounds=(lower, upper))
bounds = draw(interval(exclude_half_bounded=True, exclude_fully_unbounded=True))
return NumericalContinuousParameter(name=name, bounds=bounds)


@st.composite
Expand Down
32 changes: 32 additions & 0 deletions tests/hypothesis_strategies/targets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Hypothesis strategies for targets."""

import hypothesis.strategies as st

from baybe.targets.enum import TargetMode
from baybe.targets.numerical import _VALID_TRANSFORMATIONS, NumericalTarget

from .utils import interval

target_name = st.text(min_size=1)
"""A strategy that generates target names."""


@st.composite
def numerical_target(draw: st.DrawFn):
"""Generate :class:`baybe.targets.numerical.NumericalTarget`."""
name = draw(target_name)
mode = draw(st.sampled_from(TargetMode))
bounds = draw(
interval(
exclude_half_bounded=True, exclude_fully_unbounded=mode is TargetMode.MATCH
)
)
transformation = draw(st.sampled_from(_VALID_TRANSFORMATIONS[mode]))

return NumericalTarget(
name=name, mode=mode, bounds=bounds, transformation=transformation
)


target = numerical_target()
"""A strategy that generates targets."""
36 changes: 36 additions & 0 deletions tests/hypothesis_strategies/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Hypothesis strategies for generating utility objects."""

import hypothesis.strategies as st
from hypothesis import assume

from baybe.utils.interval import Interval


@st.composite
def interval(
draw: st.DrawFn,
*,
exclude_bounded: bool = False,
exclude_half_bounded: bool = False,
exclude_fully_unbounded: bool = False,
):
"""Generate :class:`baybe.utils.interval.Interval`."""
assert not all(
(exclude_bounded, exclude_half_bounded, exclude_fully_unbounded)
), "At least one Interval type must be allowed."

# Create interval from ordered pair of floats
bounds = (
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved
st.tuples(st.floats(), st.floats()).map(sorted).filter(lambda x: x[0] < x[1])
)
interval = Interval.create(draw(bounds))

# Filter excluded intervals
if exclude_bounded:
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved
assume(not interval.is_bounded)
if exclude_half_bounded:
assume(not interval.is_half_bounded)
if exclude_fully_unbounded:
assume(not interval.is_fully_unbounded)

return interval
2 changes: 1 addition & 1 deletion tests/serialization/test_parameter_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from baybe.parameters.base import Parameter

from ..hypothesis_strategies import (
from ..hypothesis_strategies.parameters import (
categorical_parameter,
custom_parameter,
numerical_continuous_parameter,
Expand Down
15 changes: 15 additions & 0 deletions tests/serialization/test_target_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Test serialization of targets."""

from hypothesis import given

from baybe.targets.base import Target

from ..hypothesis_strategies.targets import target


@given(target)
def test_parameter_roundtrip(target: Target):
"""A serialization roundtrip yields an equivalent object."""
string = target.to_json()
target2 = Target.from_json(string)
assert target == target2, (target, target2)
26 changes: 20 additions & 6 deletions tests/test_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,30 @@
from baybe.searchspace import SearchSpace
from baybe.strategies import Strategy
from baybe.targets import Objective
from baybe.targets.base import Target
from baybe.utils.interval import Interval


def test_deprecated_baybe_class(parameters, objective):
"""Using the deprecated ``BayBE`` class should raise a warning."""
"""Using the deprecated ``BayBE`` class raises a warning."""
with pytest.warns(DeprecationWarning):
BayBE(SearchSpace.from_product(parameters), objective)


def test_moved_objective(targets):
"""Importing ``Objective`` from ``baybe.targets`` should raise a warning."""
"""Importing ``Objective`` from ``baybe.targets`` raises a warning."""
with pytest.warns(DeprecationWarning):
Objective(mode="SINGLE", targets=targets)


def test_renamed_surrogate():
"""Importing from ``baybe.surrogate`` should raise a warning."""
"""Importing from ``baybe.surrogate`` raises a warning."""
with pytest.warns(DeprecationWarning):
from baybe.surrogate import GaussianProcessSurrogate # noqa: F401


def test_missing_strategy_type(config):
"""Specifying a strategy without a corresponding type should trigger a warning."""
"""Specifying a strategy without a corresponding type raises a warning."""
dict_ = json.loads(config)
dict_["strategy"].pop("type")
config_without_strategy_type = json.dumps(dict_)
Expand All @@ -39,12 +40,25 @@ def test_missing_strategy_type(config):


def test_deprecated_strategy_class():
"""Using the deprecated ``Strategy`` class should raise a warning."""
"""Using the deprecated ``Strategy`` class raises a warning."""
with pytest.warns(DeprecationWarning):
Strategy()


def test_deprecated_interval_is_finite():
"""Using the deprecated ``Interval.is_finite`` property should raise a warning."""
"""Using the deprecated ``Interval.is_finite`` property raises a warning."""
with pytest.warns(DeprecationWarning):
Interval(0, 1).is_finite


def test_missing_target_type():
"""Specifying a target without a corresponding type raises a warning."""
with pytest.warns(DeprecationWarning):
Target.from_json(
json.dumps(
{
"name": "missing_type",
"mode": "MAX",
}
)
)
Loading