diff --git a/CHANGELOG.md b/CHANGELOG.md index 394f4b766..b21f618d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Target enums - `mypy` for targets and intervals - Tests for code blocks in README and user guides +- `hypothesis` strategies and tests for targets and intervals +- De-/serialization of target subclasses via base class ### Changed - Renamed `bounds_transform_func` target attribute to `transformation` @@ -20,12 +22,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Wrong use of `tolerance` argument in constraints user guide - Errors with generics and type aliases in documentation +- Deduplication bug in substance_data hypothesis strategy ### Removed - Conda install instructions and version badge ### Deprecations - `Interval.is_finite` replaced with `Interval.is_bounded` +- Specifying target configs without explicit type information is deprecated ## [0.7.1] - 2023-12-07 ### Added diff --git a/baybe/objective.py b/baybe/objective.py index 29d0377d8..859f292c1 100644 --- a/baybe/objective.py +++ b/baybe/objective.py @@ -10,6 +10,7 @@ from attr import define, field from attr.validators import deep_iterable, in_, instance_of, min_len +from baybe.targets.base import Target from baybe.targets.numerical import NumericalTarget from baybe.utils import SerialMixin, geom_mean @@ -37,7 +38,7 @@ class Objective(SerialMixin): mode: Literal["SINGLE", "DESIRABILITY"] = field() """The optimization mode.""" - targets: List[NumericalTarget] = field(validator=min_len(1)) + targets: List[Target] = field(validator=min_len(1)) """The list of targets used for the objective.""" weights: List[float] = field(converter=_normalize_weights) diff --git a/baybe/targets/base.py b/baybe/targets/base.py index 00e7121b3..760324f0a 100644 --- a/baybe/targets/base.py +++ b/baybe/targets/base.py @@ -1,13 +1,20 @@ """Base functionality for all BayBE targets.""" - +import warnings from abc import ABC, abstractmethod import pandas as pd from attrs import define, field +from baybe.utils import ( + SerialMixin, + converter, + get_base_structure_hook, + unstructure_base, +) + @define(frozen=True) -class Target(ABC): +class Target(ABC, SerialMixin): """Abstract base class for all target variables. Stores information about the range, transformations, etc. @@ -29,3 +36,32 @@ def transform(self, data: pd.DataFrame) -> pd.DataFrame: Returns: A dataframe containing the transformed data. """ + + +def _add_missing_type_hook(hook): + """Adjust the structuring hook such that it auto-fills missing target types. + + Used for backward compatibility only and will be removed in future versions. + """ + + def added_type_hook(dict_, cls): + if "type" not in dict_: + warnings.warn( + f"The target type is not specified for target '{dict_['name']}' and " + f"thus automatically set to 'NumericalTarget'. " + f"However, omitting the target type is deprecated and will no longer " + f"be supported in future versions. " + f"Therefore, please add an explicit target type.", + DeprecationWarning, + ) + dict_["type"] = "NumericalTarget" + return hook(dict_, cls) + + return added_type_hook + + +# Register (un-)structure hooks +converter.register_structure_hook( + Target, _add_missing_type_hook(get_base_structure_hook(Target)) +) +converter.register_unstructure_hook(Target, unstructure_base) diff --git a/baybe/utils/interval.py b/baybe/utils/interval.py index bf47b5ff9..06f7f51f7 100644 --- a/baybe/utils/interval.py +++ b/baybe/utils/interval.py @@ -77,6 +77,11 @@ def is_half_bounded(self) -> bool: """Check if the interval is half-bounded.""" return self.is_left_bounded ^ self.is_right_bounded + @property + def is_fully_unbounded(self) -> bool: + """Check if the interval represents the entire real number line.""" + return not (self.is_left_bounded or self.is_right_bounded) + @property def is_finite(self) -> bool: """Check whether the interval is finite.""" diff --git a/tests/hypothesis_strategies/__init__.py b/tests/hypothesis_strategies/__init__.py new file mode 100644 index 000000000..feb355bd1 --- /dev/null +++ b/tests/hypothesis_strategies/__init__.py @@ -0,0 +1 @@ +"""Hypothesis strategies.""" diff --git a/tests/hypothesis_strategies/alternative_creation/__init__.py b/tests/hypothesis_strategies/alternative_creation/__init__.py new file mode 100644 index 000000000..0c4be758b --- /dev/null +++ b/tests/hypothesis_strategies/alternative_creation/__init__.py @@ -0,0 +1 @@ +"""Test alternative ways of creation not considered in the strategies.""" diff --git a/tests/hypothesis_strategies/alternative_creation/test_parameters.py b/tests/hypothesis_strategies/alternative_creation/test_parameters.py new file mode 100644 index 000000000..4041ee4b3 --- /dev/null +++ b/tests/hypothesis_strategies/alternative_creation/test_parameters.py @@ -0,0 +1,21 @@ +"""Test alternative ways of creation not considered in the strategies.""" + +import pytest + +from baybe.parameters.categorical import CategoricalParameter +from baybe.parameters.enum import CategoricalEncoding, SubstanceEncoding +from baybe.parameters.substance import SubstanceParameter + + +@pytest.mark.parametrize("encoding", [e.name for e in CategoricalEncoding]) +def test_string_encoding_categorical_parameter(encoding): + """The encoding can also be specified as a string instead of an enum value.""" + CategoricalParameter(name="string_encoding", values=["A", "B"], encoding=encoding) + + +@pytest.mark.parametrize("encoding", [e.name for e in SubstanceEncoding]) +def test_string_encoding_substance_parameter(encoding): + """The encoding can also be specified as a string instead of an enum value.""" + SubstanceParameter( + name="string_encoding", data={"A": "C", "B": "CC"}, encoding=encoding + ) diff --git a/tests/hypothesis_strategies/alternative_creation/test_targets.py b/tests/hypothesis_strategies/alternative_creation/test_targets.py new file mode 100644 index 000000000..33279bd00 --- /dev/null +++ b/tests/hypothesis_strategies/alternative_creation/test_targets.py @@ -0,0 +1,37 @@ +"""Test alternative ways of creation not considered in the strategies.""" + +import pytest +from pytest import param + +from baybe.targets.enum import TargetMode, TargetTransformation +from baybe.targets.numerical import NumericalTarget + + +@pytest.mark.parametrize( + "bounds", + [ + param((None, None), id="unbounded"), + param((0, 1), id="bounded"), + ], +) +def test_unspecified_transformation(bounds): + """The transformation can be unspecified, in which case a default is chosen. + + This explicitly tests the logic of the attrs default method. + """ + NumericalTarget("unspecified", mode="MAX", bounds=bounds) + + +@pytest.mark.parametrize("mode", (m.name for m in TargetMode)) +def test_string_mode(mode): + """The mode can also be specified as a string instead of an enum value.""" + NumericalTarget("string_mode", mode=mode, bounds=(0, 1)) + + +@pytest.mark.parametrize("transformation", (t.name for t in TargetTransformation)) +def test_string_transformation(transformation): + """The transformation can also be specified as a string instead of an enum value.""" + mode = "MAX" if transformation == "LINEAR" else "MATCH" + NumericalTarget( + "string_mode", mode=mode, bounds=(0, 1), transformation=transformation + ) diff --git a/tests/hypothesis_strategies/alternative_creation/test_utils.py b/tests/hypothesis_strategies/alternative_creation/test_utils.py new file mode 100644 index 000000000..3d5b2fe0f --- /dev/null +++ b/tests/hypothesis_strategies/alternative_creation/test_utils.py @@ -0,0 +1,19 @@ +"""Test alternative ways of creation not considered in the strategies.""" + +import pytest +from pytest import param + +from baybe.utils.interval import Interval + + +@pytest.mark.parametrize( + ("lower", "upper"), + [ + param(None, 0, id="left-unbounded"), + param(0, None, id="right-unbounded"), + param(None, None, id="fully-unbounded"), + ], +) +def test_none_bounds(lower, upper): + """Bounds can also be None.""" + Interval(lower, upper) diff --git a/tests/hypothesis_strategies.py b/tests/hypothesis_strategies/parameters.py similarity index 87% rename from tests/hypothesis_strategies.py rename to tests/hypothesis_strategies/parameters.py index 9623257ed..1c6d7822e 100644 --- a/tests/hypothesis_strategies.py +++ b/tests/hypothesis_strategies/parameters.py @@ -1,4 +1,4 @@ -"""Hypothesis strategies.""" +"""Hypothesis strategies for parameters.""" import hypothesis.strategies as st import numpy as np @@ -18,13 +18,7 @@ from baybe.utils.chemistry import get_canonical_smiles from baybe.utils.numeric import DTypeFloatNumpy -_largest_lower_interval = np.nextafter( - np.nextafter(np.inf, 0, dtype=DTypeFloatNumpy), 0, dtype=DTypeFloatNumpy -) -""" -The largest possible value for the lower end of a continuous interval such that there -still exists a larger but finite number for the upper interval end. -""" +from .utils import interval decorrelation = st.one_of( st.booleans(), @@ -55,9 +49,13 @@ def substance_data(draw: st.DrawFn): """Generate data for :class:`baybe.parameters.substance.SubstanceParameter`.""" names = draw(st.lists(st.text(min_size=1), min_size=2, max_size=10, unique=True)) substances = draw( - st.lists(smiles(), min_size=len(names), max_size=len(names), unique=True) + st.lists( + smiles().map(get_canonical_smiles), + min_size=len(names), + max_size=len(names), + unique=True, + ) ) - substances = list(set(get_canonical_smiles(s) for s in substances)) return dict(zip(names, substances)) @@ -109,9 +107,8 @@ def numerical_discrete_parameter( def numerical_continuous_parameter(draw: st.DrawFn): """Generate :class:`baybe.parameters.numerical.NumericalContinuousParameter`.""" name = draw(parameter_name) - lower = draw(st.floats(max_value=_largest_lower_interval, allow_infinity=False)) - upper = draw(st.floats(min_value=lower, exclude_min=True, allow_infinity=False)) - return NumericalContinuousParameter(name=name, bounds=(lower, upper)) + bounds = draw(interval(exclude_half_bounded=True, exclude_fully_unbounded=True)) + return NumericalContinuousParameter(name=name, bounds=bounds) @st.composite diff --git a/tests/hypothesis_strategies/targets.py b/tests/hypothesis_strategies/targets.py new file mode 100644 index 000000000..bf153fb04 --- /dev/null +++ b/tests/hypothesis_strategies/targets.py @@ -0,0 +1,32 @@ +"""Hypothesis strategies for targets.""" + +import hypothesis.strategies as st + +from baybe.targets.enum import TargetMode +from baybe.targets.numerical import _VALID_TRANSFORMATIONS, NumericalTarget + +from .utils import interval + +target_name = st.text(min_size=1) +"""A strategy that generates target names.""" + + +@st.composite +def numerical_target(draw: st.DrawFn): + """Generate :class:`baybe.targets.numerical.NumericalTarget`.""" + name = draw(target_name) + mode = draw(st.sampled_from(TargetMode)) + bounds = draw( + interval( + exclude_half_bounded=True, exclude_fully_unbounded=mode is TargetMode.MATCH + ) + ) + transformation = draw(st.sampled_from(_VALID_TRANSFORMATIONS[mode])) + + return NumericalTarget( + name=name, mode=mode, bounds=bounds, transformation=transformation + ) + + +target = numerical_target() +"""A strategy that generates targets.""" diff --git a/tests/hypothesis_strategies/utils.py b/tests/hypothesis_strategies/utils.py new file mode 100644 index 000000000..54089a1ba --- /dev/null +++ b/tests/hypothesis_strategies/utils.py @@ -0,0 +1,36 @@ +"""Hypothesis strategies for generating utility objects.""" + +import hypothesis.strategies as st +from hypothesis import assume + +from baybe.utils.interval import Interval + + +@st.composite +def interval( + draw: st.DrawFn, + *, + exclude_bounded: bool = False, + exclude_half_bounded: bool = False, + exclude_fully_unbounded: bool = False, +): + """Generate :class:`baybe.utils.interval.Interval`.""" + assert not all( + (exclude_bounded, exclude_half_bounded, exclude_fully_unbounded) + ), "At least one Interval type must be allowed." + + # Create interval from ordered pair of floats + bounds = ( + st.tuples(st.floats(), st.floats()).map(sorted).filter(lambda x: x[0] < x[1]) + ) + interval = Interval.create(draw(bounds)) + + # Filter excluded intervals + if exclude_bounded: + assume(not interval.is_bounded) + if exclude_half_bounded: + assume(not interval.is_half_bounded) + if exclude_fully_unbounded: + assume(not interval.is_fully_unbounded) + + return interval diff --git a/tests/serialization/test_parameter_serialization.py b/tests/serialization/test_parameter_serialization.py index fb1ccf264..e732f1d08 100644 --- a/tests/serialization/test_parameter_serialization.py +++ b/tests/serialization/test_parameter_serialization.py @@ -7,7 +7,7 @@ from baybe.parameters.base import Parameter -from ..hypothesis_strategies import ( +from ..hypothesis_strategies.parameters import ( categorical_parameter, custom_parameter, numerical_continuous_parameter, diff --git a/tests/serialization/test_target_serialization.py b/tests/serialization/test_target_serialization.py new file mode 100644 index 000000000..e30789375 --- /dev/null +++ b/tests/serialization/test_target_serialization.py @@ -0,0 +1,15 @@ +"""Test serialization of targets.""" + +from hypothesis import given + +from baybe.targets.base import Target + +from ..hypothesis_strategies.targets import target + + +@given(target) +def test_parameter_roundtrip(target: Target): + """A serialization roundtrip yields an equivalent object.""" + string = target.to_json() + target2 = Target.from_json(string) + assert target == target2, (target, target2) diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index e29422639..adc450084 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -8,29 +8,30 @@ from baybe.searchspace import SearchSpace from baybe.strategies import Strategy from baybe.targets import Objective +from baybe.targets.base import Target from baybe.utils.interval import Interval def test_deprecated_baybe_class(parameters, objective): - """Using the deprecated ``BayBE`` class should raise a warning.""" + """Using the deprecated ``BayBE`` class raises a warning.""" with pytest.warns(DeprecationWarning): BayBE(SearchSpace.from_product(parameters), objective) def test_moved_objective(targets): - """Importing ``Objective`` from ``baybe.targets`` should raise a warning.""" + """Importing ``Objective`` from ``baybe.targets`` raises a warning.""" with pytest.warns(DeprecationWarning): Objective(mode="SINGLE", targets=targets) def test_renamed_surrogate(): - """Importing from ``baybe.surrogate`` should raise a warning.""" + """Importing from ``baybe.surrogate`` raises a warning.""" with pytest.warns(DeprecationWarning): from baybe.surrogate import GaussianProcessSurrogate # noqa: F401 def test_missing_strategy_type(config): - """Specifying a strategy without a corresponding type should trigger a warning.""" + """Specifying a strategy without a corresponding type raises a warning.""" dict_ = json.loads(config) dict_["strategy"].pop("type") config_without_strategy_type = json.dumps(dict_) @@ -39,12 +40,25 @@ def test_missing_strategy_type(config): def test_deprecated_strategy_class(): - """Using the deprecated ``Strategy`` class should raise a warning.""" + """Using the deprecated ``Strategy`` class raises a warning.""" with pytest.warns(DeprecationWarning): Strategy() def test_deprecated_interval_is_finite(): - """Using the deprecated ``Interval.is_finite`` property should raise a warning.""" + """Using the deprecated ``Interval.is_finite`` property raises a warning.""" with pytest.warns(DeprecationWarning): Interval(0, 1).is_finite + + +def test_missing_target_type(): + """Specifying a target without a corresponding type raises a warning.""" + with pytest.warns(DeprecationWarning): + Target.from_json( + json.dumps( + { + "name": "missing_type", + "mode": "MAX", + } + ) + ) diff --git a/tests/test_targets.py b/tests/test_objective.py similarity index 72% rename from tests/test_targets.py rename to tests/test_objective.py index 50c4107fb..8eb489021 100644 --- a/tests/test_targets.py +++ b/tests/test_objective.py @@ -1,4 +1,4 @@ -"""Tests for the targets module.""" +"""Tests for the objective module.""" import pytest @@ -6,35 +6,6 @@ from baybe.targets import NumericalTarget -class TestInvalidTargetCreation: - """Invalid target creation raises expected error.""" - - def test_missing_bounds_for_match_mode(self): - with pytest.raises(ValueError): - NumericalTarget( - name="missing_bounds", - mode="MATCH", - ) - - def test_incompatible_transformation_for_match_mode(self): - with pytest.raises(ValueError): - NumericalTarget( - name="incompatible_transform", - mode="MATCH", - bounds=(0, 100), - transformation="LINEAR", - ) - - def test_invalid_transformation(self): - with pytest.raises(ValueError): - NumericalTarget( - name="invalid_transform", - mode="MATCH", - bounds=(0, 100), - transformation="SOME_STUFF", - ) - - class TestInvalidObjectiveCreation: """Invalid objective creation raises expected error.""" diff --git a/tests/validation/test_interval_validation.py b/tests/validation/test_interval_validation.py new file mode 100644 index 000000000..ae5ea3ba6 --- /dev/null +++ b/tests/validation/test_interval_validation.py @@ -0,0 +1,31 @@ +"""Validation tests for intervals.""" + +import pytest +from pytest import param + +from baybe.utils.interval import Interval + + +@pytest.mark.parametrize( + "bounds", + [ + param((0.0, 0.0), id="single_element"), + param((1.0, 0.0), id="unsorted_bounds"), + ], +) +def test_invalid_range(bounds): + """Providing a non-increasing pair of floats raises an exception.""" + with pytest.raises(ValueError): + Interval(*bounds) + + +@pytest.mark.parametrize( + "bounds", + [ + param(("a", 0.0), id="string"), + ], +) +def test_invalid_types(bounds): + """Providing an invalid bound type raises an exception.""" + with pytest.raises(ValueError): + Interval(*bounds) diff --git a/tests/validation/test_parameter_validation.py b/tests/validation/test_parameter_validation.py index 1b787a26d..f0914dcb7 100644 --- a/tests/validation/test_parameter_validation.py +++ b/tests/validation/test_parameter_validation.py @@ -10,7 +10,6 @@ from pytest import param from baybe.parameters.categorical import ( - CategoricalEncoding, CategoricalParameter, TaskParameter, ) @@ -19,7 +18,7 @@ NumericalContinuousParameter, NumericalDiscreteParameter, ) -from baybe.parameters.substance import SubstanceEncoding, SubstanceParameter +from baybe.parameters.substance import SubstanceParameter from baybe.parameters.validation import validate_decorrelation from baybe.utils.interval import InfiniteIntervalError @@ -86,12 +85,6 @@ def test_invalid_bounds_numerical_continuous_parameter(bounds): NumericalContinuousParameter(name="invalid_values", bounds=bounds) -@pytest.mark.parametrize("encoding", [e.name for e in CategoricalEncoding]) -def test_string_encoding_categorical_parameter(encoding): - """The encoding can also be specified as a string instead of an enum value.""" - CategoricalParameter(name="string_encoding", values=["A", "B"], encoding=encoding) - - @pytest.mark.parametrize( ("values", "error"), [ @@ -131,14 +124,6 @@ def test_invalid_values_task_parameter(values, active_values, error): TaskParameter(name="invalid_values", values=values, active_values=active_values) -@pytest.mark.parametrize("encoding", [e.name for e in SubstanceEncoding]) -def test_string_encoding_substance_parameter(encoding): - """The encoding can also be specified as a string instead of an enum value.""" - SubstanceParameter( - name="string_encoding", data={"A": "C", "B": "CC"}, encoding=encoding - ) - - @pytest.mark.parametrize( ("data", "error"), [ diff --git a/tests/validation/test_target_validation.py b/tests/validation/test_target_validation.py new file mode 100644 index 000000000..8b7ff7955 --- /dev/null +++ b/tests/validation/test_target_validation.py @@ -0,0 +1,36 @@ +"""Validation tests for targets.""" + +import pytest +from pytest import param + +from baybe.targets.numerical import NumericalTarget + + +@pytest.mark.parametrize( + ("mode", "bounds"), + [ + param("MATCH", None, id="non_closed_match_mode"), + param("MAX", (0, None), id="half_open"), + ], +) +def test_invalid_bounds_mode(mode, bounds): + """Providing invalid bounds raises an exception.""" + with pytest.raises(ValueError): + NumericalTarget(name="invalid_bounds", mode=mode, bounds=bounds) + + +@pytest.mark.parametrize( + ("mode", "bounds", "transformation"), + [ + param("MIN", None, "BELL", id="bell_for_min"), + param("MATCH", (0, 1), "LINEAR", id="linear_for_match"), + ], +) +def test_incompatible_transform_mode(mode, bounds, transformation): + with pytest.raises(ValueError): + NumericalTarget( + name="incompatible_transform", + mode=mode, + bounds=bounds, + transformation=transformation, + )