diff --git a/CHANGELOG.md b/CHANGELOG.md
index 54711d9fc..747e1373d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,10 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `mypy` for search space and objectives
 - Class hierarchy for objectives
 - Deserialization is now also possible from optional class name abbreviations
-- `Kernel` base class allowing to specify kernels
-- `MaternKernel` class can be chosen for GP surrogates
-- `hypothesis` strategies and roundtrip test for kernels, constraints, objectives, priors
-  and acquisition functions
+- `Kernel`, `MaternKernel`, and `ScaleKernel` classes for specifying kernels
+- `hypothesis` strategies and roundtrip test for kernels, constraints, objectives,
+  priors and acquisition functions
 - New acquisition functions: `qSR`, `qNEI`, `LogEI`, `qLogEI`, `qLogNEI`
 - Serialization user guide
 - Basic deserialization tests using different class type specifiers
diff --git a/baybe/kernels/__init__.py b/baybe/kernels/__init__.py
index 95d66f7f6..b9a5ecb12 100644
--- a/baybe/kernels/__init__.py
+++ b/baybe/kernels/__init__.py
@@ -1,5 +1,8 @@
 """Kernels for Gaussian process surrogate models."""
 
-from baybe.kernels.basic import MaternKernel
+from baybe.kernels.basic import MaternKernel, ScaleKernel
 
-__all__ = ["MaternKernel"]
+__all__ = [
+    "MaternKernel",
+    "ScaleKernel",
+]
diff --git a/baybe/kernels/base.py b/baybe/kernels/base.py
index 9492c2444..a01a5a9d7 100644
--- a/baybe/kernels/base.py
+++ b/baybe/kernels/base.py
@@ -1,9 +1,11 @@
 """Base classes for all kernels."""
 
+from __future__ import annotations
+
 from abc import ABC
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
-from attrs import define, field
+from attrs import define
 
 from baybe.kernels.priors.base import Prior
 from baybe.serialization.core import (
@@ -12,31 +14,76 @@
     unstructure_base,
 )
 from baybe.serialization.mixin import SerialMixin
-from baybe.utils.basic import filter_attributes
+from baybe.utils.basic import filter_attributes, get_baseclasses
+
+if TYPE_CHECKING:
+    import torch
 
 
 @define(frozen=True)
 class Kernel(ABC, SerialMixin):
     """Abstract base class for all kernels."""
 
-    lengthscale_prior: Optional[Prior] = field(default=None, kw_only=True)
-    """An optional prior on the kernel lengthscale."""
-
-    def to_gpytorch(self, *args, **kwargs):
+    def to_gpytorch(
+        self,
+        *,
+        ard_num_dims: Optional[int] = None,
+        batch_shape: Optional[torch.Size] = None,
+        active_dims: Optional[tuple[int, ...]] = None,
+    ):
         """Create the gpytorch representation of the kernel."""
         import gpytorch.kernels
 
+        # Fetch the necessary gpytorch constructor parameters of the kernel.
+        # NOTE: In gpytorch, some attributes (like the kernel lengthscale) are handled
+        #   via the `gpytorch.kernels.Kernel` base class. Hence, checking the fields of
+        #   the actual class alone is not sufficient; those of its base classes must be
+        #   considered as well.
         kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
-        fields_dict = filter_attributes(object=self, callable_=kernel_cls.__init__)
+        base_classes = get_baseclasses(kernel_cls, abstract=True)
+        fields_dict = {}
+        for cls in [kernel_cls, *base_classes]:
+            fields_dict.update(filter_attributes(object=self, callable_=cls.__init__))
+
+        # Convert specified priors to gpytorch, if provided
+        prior_dict = {
+            key: value.to_gpytorch()
+            for key, value in fields_dict.items()
+            if isinstance(value, Prior)
+        }
+
+        # Convert specified inner kernels to gpytorch, if provided
+        kernel_dict = {
+            key: value.to_gpytorch(
+                ard_num_dims=ard_num_dims,
+                batch_shape=batch_shape,
+                active_dims=active_dims,
+            )
+            for key, value in fields_dict.items()
+            if isinstance(value, Kernel)
+        }
+
+        # Create the kernel with all its inner gpytorch objects
+        fields_dict.update(kernel_dict)
+        fields_dict.update(prior_dict)
+        gpytorch_kernel = kernel_cls(**fields_dict)
+
+        # If the kernel has a lengthscale, set its initial value
+        if kernel_cls.has_lengthscale:
+            import torch
 
-        # If a lengthscale prior was chosen, we manually add it to the dictionary
-        if self.lengthscale_prior is not None:
-            fields_dict["lengthscale_prior"] = self.lengthscale_prior.to_gpytorch()
+            from baybe.utils.torch import DTypeFloatTorch
 
-        # Update kwargs to contain class-specific attributes
-        kwargs.update(fields_dict)
+            # We can ignore mypy here and simply assume that the corresponding BayBE
+            #   kernel class has the necessary lengthscale attribute defined. This is
+            #   safer than a `hasattr` check in the above if-condition, which would
+            #   silently fail if the attribute were forgotten or misspelled in a new
+            #   kernel class.
+            if (initial_value := self.lengthscale_initial_value) is not None:  # type: ignore[attr-defined]
+                gpytorch_kernel.lengthscale = torch.tensor(
+                    initial_value, dtype=DTypeFloatTorch
+                )
 
-        return kernel_cls(*args, **kwargs)
+        return gpytorch_kernel
 
 
 # Register de-/serialization hooks
diff --git a/baybe/kernels/basic.py b/baybe/kernels/basic.py
index 32d0df07a..d4083bc54 100644
--- a/baybe/kernels/basic.py
+++ b/baybe/kernels/basic.py
@@ -1,35 +1,16 @@
 """Collection of kernels."""
 
-from fractions import Fraction
-from typing import Union
+from typing import Optional
 
 from attrs import define, field
-from attrs.validators import in_
+from attrs.converters import optional as optional_c
+from attrs.validators import in_, instance_of
+from attrs.validators import optional as optional_v
 
 from baybe.kernels.base import Kernel
-
-
-def _convert_fraction(value: Union[str, float, Fraction], /) -> float:
-    """Convert the provided value into a float.
-
-    Args:
-        value: The parameter that should be converted.
-
-    Returns:
-        The float representation of the given input.
-
-    Raises:
-        ValueError: If the input was provided as string but could not be interpreted as
-            fraction.
-    """
-    if isinstance(value, str):
-        try:
-            value = Fraction(value)
-        except ValueError as err:
-            raise ValueError(
-                f"The provided input '{value}' could not be interpreted as a fraction."
-            ) from err
-    return float(value)
+from baybe.kernels.priors.base import Prior
+from baybe.utils.conversion import fraction_to_float
+from baybe.utils.validation import finite_float
 
 
 @define(frozen=True)
@@ -37,9 +18,50 @@ class MaternKernel(Kernel):
     """A Matern kernel using a smoothness parameter."""
 
     nu: float = field(
-        converter=_convert_fraction, validator=in_([0.5, 1.5, 2.5]), default=2.5
+        converter=fraction_to_float, validator=in_([0.5, 1.5, 2.5]), default=2.5
     )
     """A smoothness parameter.
 
     Only takes the values 0.5, 1.5 or 2.5. Larger values yield smoother interpolations.
     """
+
+    lengthscale_prior: Optional[Prior] = field(
+        default=None, validator=optional_v(instance_of(Prior))
+    )
+    """An optional prior on the kernel lengthscale."""
+
+    lengthscale_initial_value: Optional[float] = field(
+        default=None, converter=optional_c(float), validator=optional_v(finite_float)
+    )
+    """An optional initial value for the kernel lengthscale."""
+
+
+@define(frozen=True)
+class ScaleKernel(Kernel):
+    """A kernel for decorating existing kernels with an outputscale."""
+
+    base_kernel: Kernel = field(validator=instance_of(Kernel))
+    """The base kernel that is being decorated."""
+
+    outputscale_prior: Optional[Prior] = field(
+        default=None, validator=optional_v(instance_of(Prior))
+    )
+    """An optional prior on the output scale."""
+
+    outputscale_initial_value: Optional[float] = field(
+        default=None, converter=optional_c(float), validator=optional_v(finite_float)
+    )
+    """An optional initial value for the output scale."""
+
+    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
+        # See base class.
+        import torch
+
+        from baybe.utils.torch import DTypeFloatTorch
+
+        gpytorch_kernel = super().to_gpytorch(*args, **kwargs)
+        if (initial_value := self.outputscale_initial_value) is not None:
+            gpytorch_kernel.outputscale = torch.tensor(
+                initial_value, dtype=DTypeFloatTorch
+            )
+        return gpytorch_kernel
diff --git a/baybe/surrogates/gaussian_process.py b/baybe/surrogates/gaussian_process.py
index dce9c9a88..ec462a692 100644
--- a/baybe/surrogates/gaussian_process.py
+++ b/baybe/surrogates/gaussian_process.py
@@ -6,7 +6,7 @@
 
 from attr import define, field
 
-from baybe.kernels import MaternKernel
+from baybe.kernels import MaternKernel, ScaleKernel
 from baybe.kernels.base import Kernel
 from baybe.kernels.priors import GammaPrior
 from baybe.searchspace import SearchSpace
@@ -108,25 +108,21 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No
         # If no kernel is provided, we construct one from our priors
         if self.kernel is None:
-            self.kernel = MaternKernel(lengthscale_prior=lengthscale_prior[0])
+            self.kernel = ScaleKernel(
+                base_kernel=MaternKernel(
+                    lengthscale_prior=lengthscale_prior[0],
+                    lengthscale_initial_value=lengthscale_prior[1],
+                ),
+                outputscale_prior=outputscale_prior[0],
+                outputscale_initial_value=outputscale_prior[1],
+            )
 
         # define the covariance module for the numeric dimensions
-        gpytorch_kernel = self.kernel.to_gpytorch(
+        base_covar_module = self.kernel.to_gpytorch(
             ard_num_dims=train_x.shape[-1] - n_task_params,
             active_dims=numeric_idxs,
             batch_shape=batch_shape,
         )
-        base_covar_module = gpytorch.kernels.ScaleKernel(
-            gpytorch_kernel,
-            batch_shape=batch_shape,
-            outputscale_prior=outputscale_prior[0].to_gpytorch(),
-        )
-        if outputscale_prior[1] is not None:
-            base_covar_module.outputscale = torch.tensor([outputscale_prior[1]])
-        if lengthscale_prior[1] is not None:
-            base_covar_module.base_kernel.lengthscale = torch.tensor(
-                [lengthscale_prior[1]]
-            )
 
         # create GP covariance
         if task_idx is None:
diff --git a/baybe/utils/basic.py b/baybe/utils/basic.py
index e2b95715f..2e9f63186 100644
--- a/baybe/utils/basic.py
+++ b/baybe/utils/basic.py
@@ -54,6 +54,37 @@ def get_subclasses(cls: _C, recursive: bool = True, abstract: bool = False) -> l
     return subclasses
 
 
+def get_baseclasses(
+    cls: type,
+    recursive: bool = True,
+    abstract: bool = False,
+) -> list[type]:
+    """Return a list of base classes for the given class.
+
+    Args:
+        cls: The class to retrieve base classes for.
+        recursive: If ``True``, indirect base classes (i.e., base classes of base
+            classes) are included.
+        abstract: If ``True``, abstract base classes are included.
+
+    Returns:
+        A list of base classes for the given class.
+    """
+    from baybe.utils.boolean import is_abstract
+
+    classes = []
+
+    for baseclass in cls.__bases__:
+        if baseclass not in classes:
+            if abstract or not is_abstract(baseclass):
+                classes.append(baseclass)
+
+            if recursive:
+                classes.extend(get_baseclasses(baseclass, abstract=abstract))
+
+    return classes
+
+
 def set_random_seed(seed: int):
     """Set the global random seed.
 
diff --git a/baybe/utils/conversion.py b/baybe/utils/conversion.py
new file mode 100644
index 000000000..bc9c47476
--- /dev/null
+++ b/baybe/utils/conversion.py
@@ -0,0 +1,27 @@
+"""Conversion utilities."""
+
+from fractions import Fraction
+from typing import Union
+
+
+def fraction_to_float(value: Union[str, float, Fraction], /) -> float:
+    """Convert the provided input representing a fraction into a float.
+
+    Args:
+        value: The input to be converted.
+
+    Returns:
+        The float representation of the given input.
+
+    Raises:
+        ValueError: If the input was provided as a string but could not be interpreted
+            as a fraction.
+    """
+    if isinstance(value, str):
+        try:
+            value = Fraction(value)
+        except ValueError as err:
+            raise ValueError(
+                f"The provided input '{value}' could not be interpreted as a fraction."
+            ) from err
+    return float(value)
diff --git a/tests/hypothesis_strategies/kernels.py b/tests/hypothesis_strategies/kernels.py
index 2eaff2d1d..de5387d15 100644
--- a/tests/hypothesis_strategies/kernels.py
+++ b/tests/hypothesis_strategies/kernels.py
@@ -2,13 +2,36 @@
 
 import hypothesis.strategies as st
 
-from baybe.kernels import MaternKernel
+from baybe.kernels import MaternKernel, ScaleKernel
 
+from ..hypothesis_strategies.basic import finite_floats
 from ..hypothesis_strategies.priors import priors
 
 matern_kernels = st.builds(
     MaternKernel,
     nu=st.sampled_from((0.5, 1.5, 2.5)),
     lengthscale_prior=st.one_of(st.none(), priors),
+    lengthscale_initial_value=st.one_of(st.none(), finite_floats()),
 )
 """A strategy that generates Matern kernels."""
+
+
+base_kernels = st.one_of([matern_kernels])
+"""A strategy that generates base kernels to be used within more complex kernels."""
+
+
+@st.composite
+def kernels(draw: st.DrawFn):
+    """Generate :class:`baybe.kernels.basic.Kernel`."""
+    base_kernel = draw(base_kernels)
+    add_scale = draw(st.booleans())
+    if add_scale:
+        return ScaleKernel(
+            base_kernel=base_kernel,
+            outputscale_prior=draw(st.one_of(st.none(), priors)),
+            outputscale_initial_value=draw(
+                st.one_of(st.none(), finite_floats()),
+            ),
+        )
+    else:
+        return base_kernel
diff --git a/tests/serialization/test_kernel_serialization.py b/tests/serialization/test_kernel_serialization.py
index 4644fe94d..35145027b 100644
--- a/tests/serialization/test_kernel_serialization.py
+++ b/tests/serialization/test_kernel_serialization.py
@@ -2,12 +2,12 @@
 
 from hypothesis import given
 
-from baybe.kernels import MaternKernel
-from tests.hypothesis_strategies.kernels import matern_kernels
+from baybe.kernels.base import Kernel
+from tests.hypothesis_strategies.kernels import kernels
 
 
-@given(matern_kernels)
-def test_matern_kernel_roundtrip(kernel: MaternKernel):
+@given(kernels())
+def test_kernel_roundtrip(kernel: Kernel):
     string = kernel.to_json()
-    kernel2 = MaternKernel.from_json(string)
+    kernel2 = Kernel.from_json(string)
     assert kernel == kernel2, (kernel, kernel2)
diff --git a/tests/test_iterations.py b/tests/test_iterations.py
index 98b1e50c4..1077f030f 100644
--- a/tests/test_iterations.py
+++ b/tests/test_iterations.py
@@ -5,6 +5,7 @@ import pytest
 
 from baybe.acquisition.base import AcquisitionFunction
+from baybe.kernels import MaternKernel, ScaleKernel
 from baybe.kernels.priors import (
     GammaPrior,
     HalfCauchyPrior,
 )
@@ -130,6 +131,16 @@
     SmoothedBoxPrior(0, 3, 0.1),
 ]
 
+valid_base_kernels = [MaternKernel(lengthscale_prior=prior) for prior in valid_priors]
+
+valid_scale_kernels = [
+    ScaleKernel(base_kernel=base_kernel, outputscale_prior=prior)
+    for base_kernel in valid_base_kernels
+    for prior in valid_priors
+]
+
+valid_kernels = valid_base_kernels + valid_scale_kernels
+
 test_targets = [
     ["Target_max"],
     ["Target_min"],
@@ -171,6 +182,15 @@ def test_iter_prior(campaign, n_iterations, batch_size):
     run_iterations(campaign, n_iterations, batch_size)
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "kernel", valid_kernels, ids=[c.__class__.__name__ for c in valid_kernels]
+)
+@pytest.mark.parametrize("n_iterations", [3], ids=["i3"])
+def test_iter_kernel(campaign, n_iterations, batch_size):
+    run_iterations(campaign, n_iterations, batch_size)
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize(
     "surrogate_model",
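
Usage sketch (not part of the patch): the snippet below illustrates the kernel API introduced by this diff, assuming the package state after it is applied. It composes a `ScaleKernel` around a `MaternKernel`, converts the result to its gpytorch representation via the new keyword-only `to_gpytorch` signature, and round-trips it through JSON the way the updated serialization test does. The `GammaPrior` parameters and the `ard_num_dims` value are illustrative choices, not values mandated by the diff.

```python
from baybe.kernels import MaternKernel, ScaleKernel
from baybe.kernels.base import Kernel
from baybe.kernels.priors import GammaPrior

# Compose a scaled Matern kernel; priors and initial values are optional fields.
kernel = ScaleKernel(
    base_kernel=MaternKernel(
        nu=2.5,  # smoothness: only 0.5, 1.5, or 2.5 are accepted
        lengthscale_prior=GammaPrior(3.0, 6.0),  # illustrative prior parameters
        lengthscale_initial_value=0.5,
    ),
    outputscale_prior=GammaPrior(2.0, 0.15),
    outputscale_initial_value=1.0,
)

# `to_gpytorch` now takes keyword-only arguments and recursively converts
# inner kernels and priors (see `Kernel.to_gpytorch` in baybe/kernels/base.py).
gpytorch_kernel = kernel.to_gpytorch(ard_num_dims=3)

# Kernels round-trip through JSON via the base class, mirroring
# tests/serialization/test_kernel_serialization.py.
assert kernel == Kernel.from_json(kernel.to_json())
```

Note how the initial values are applied after construction: the base class sets the lengthscale for any gpytorch kernel reporting `has_lengthscale`, while `ScaleKernel.to_gpytorch` additionally sets the outputscale on the constructed object.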