diff --git a/bofire/kernels/categorical.py b/bofire/kernels/categorical.py
index 640a62d1..8659db6f 100644
--- a/bofire/kernels/categorical.py
+++ b/bofire/kernels/categorical.py
@@ -1,24 +1,65 @@
+from typing import Dict
+
 import torch
+from botorch.models.transforms.input import OneHotToNumeric
 from gpytorch.kernels.kernel import Kernel
 from torch import Tensor
 
 
 class HammingKernelWithOneHots(Kernel):
+    r"""
+    A kernel for one-hot encoded categorical features. The inputs
+    may contain more than one categorical feature.
+
+    Computes `exp(-dist(x1, x2) / lengthscale)`, where
+    `dist(x1, x2)` is zero if `x1` and `x2` correspond to the
+    same category and one otherwise. If the last dimension
+    is not a batch dimension, the distance is averaged over
+    the categorical features.
+
+    Note: This kernel is NOT differentiable w.r.t. the inputs.
+    """
+
     has_lengthscale = True
 
+    def __init__(self, categorical_features: Dict[int, int], *args, **kwargs):
+        """
+        Initialize.
+
+        Args:
+            categorical_features: A dictionary mapping the starting index of each
+                categorical feature to its cardinality. This assumes that the
+                categoricals are one-hot encoded.
+            *args, **kwargs: Passed to gpytorch.kernels.kernel.Kernel.
+        """
+        super().__init__(*args, **kwargs)
+
+        onehot_dim = sum(categorical_features.values())
+        self.trx = OneHotToNumeric(
+            onehot_dim, categorical_features=categorical_features
+        )
+
     def forward(
         self,
         x1: Tensor,
        x2: Tensor,
         diag: bool = False,
         last_dim_is_batch: bool = False,
+        **params,
     ) -> Tensor:
-        delta = (x1.unsqueeze(-2) - x2.unsqueeze(-3)) ** 2
-        dists = delta / self.lengthscale.unsqueeze(-2)
+        # Map the one-hot blocks to integer category labels before comparing.
+        x1 = self.trx(x1)
+        x2 = self.trx(x2)
+
+        delta = x1.unsqueeze(-2) != x2.unsqueeze(-3)
+        if self.ard_num_dims is not None:
+            # The lengthscale was sized for the one-hot inputs; only the leading
+            # entries (one per categorical feature) are used.
+            ls = self.lengthscale[..., : delta.shape[-1]]
+        else:
+            ls = self.lengthscale
+
+        dists = delta / ls.unsqueeze(-2)
         if last_dim_is_batch:
             dists = dists.transpose(-3, -1)
-
-        dists = dists.sum(-1) / 2
+        else:
+            dists = dists.mean(-1)
         res = torch.exp(-dists)
         if diag:
             res = torch.diagonal(res, dim1=-1, dim2=-2)
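To make the docstring's formula concrete, here is a minimal sketch of what the kernel computes on a toy input. The feature layout and lengthscale value are made up for illustration; the expected value follows directly from the averaged Hamming distance:

    import torch

    from bofire.kernels.categorical import HammingKernelWithOneHots

    # Two one-hot encoded categoricals: cardinality 2 (columns 0:2) and 3 (columns 2:5).
    k = HammingKernelWithOneHots(categorical_features={0: 2, 2: 3})
    k.lengthscale = torch.tensor(1.0)

    # x1 and x2 share the first category but differ in the second, so the mean
    # Hamming distance over the two features is (0 + 1) / 2 = 0.5.
    x1 = torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0]])
    x2 = torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0]])

    # With a unit lengthscale the kernel value is exp(-0.5) ~= 0.6065.
    assert torch.allclose(k(x1, x2).to_dense(), torch.exp(torch.tensor(-0.5)))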
diff --git a/bofire/kernels/mapper.py b/bofire/kernels/mapper.py
index d0e6d3d5..ad8519ac 100644
--- a/bofire/kernels/mapper.py
+++ b/bofire/kernels/mapper.py
@@ -218,22 +218,35 @@ def map_HammingDistanceKernel(
     active_dims: List[int],
     features_to_idx_mapper: Optional[Callable[[List[str]], List[int]]],
 ) -> GpytorchKernel:
-    active_dims = _compute_active_dims(data_model, active_dims, features_to_idx_mapper)
+    if data_model.features is not None:
+        if features_to_idx_mapper is None:
+            raise RuntimeError(
+                "features_to_idx_mapper must be defined when using only a subset of features"
+            )
 
-    with_one_hots = data_model.features is not None and len(active_dims) > 1
-    if with_one_hots and len(active_dims) == 1:
-        raise RuntimeError(
-            "only one feature for categorical kernel operating on one-hot features"
-        )
-    elif not with_one_hots and len(active_dims) > 1:
-        # this is not necessarily an issue since botorch's CategoricalKernel
-        # can work on multiple features at the same time
-        pass
+        active_dims = []
+        categorical_features = {}
+        for k in data_model.features:
+            idx = features_to_idx_mapper([k])
+            categorical_features[len(active_dims)] = len(idx)
+
+            already_used = [i for i in idx if i in active_dims]
+            if already_used:
+                raise RuntimeError(
+                    f"indices {already_used} are used in more than one categorical feature"
+                )
+
+            active_dims.extend(idx)
+
+            if len(idx) == 1:
+                raise RuntimeError(
+                    f"feature {k} is supposed to be one-hot encoded but is mapped to a single dimension"
+                )
 
-    if with_one_hots:
         return HammingKernelWithOneHots(
-            batch_shape=batch_shape,
+            categorical_features=categorical_features,
             ard_num_dims=len(active_dims) if data_model.ard else None,
+            batch_shape=batch_shape,
             active_dims=active_dims,  # type: ignore
             lengthscale_constraint=GreaterThan(1e-06),
         )
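To illustrate the bookkeeping in map_HammingDistanceKernel: the keys of categorical_features are each feature's offset within the kernel's own active dimensions, not its position in the transformed input. A small sketch with a hypothetical feature-to-index layout (the same one the mapper tests below use):

    # Hypothetical layout: x_cat_1 occupies transformed columns [5, 6, 7, 8],
    # x_cat_2 occupies columns [2, 3].
    fmap = {"x_cat_1": [5, 6, 7, 8], "x_cat_2": [2, 3]}

    active_dims, categorical_features = [], {}
    for feat in ["x_cat_1", "x_cat_2"]:
        idx = fmap[feat]
        # key: offset of this feature inside active_dims; value: its cardinality
        categorical_features[len(active_dims)] = len(idx)
        active_dims.extend(idx)

    assert active_dims == [5, 6, 7, 8, 2, 3]
    assert categorical_features == {0: 4, 4: 2}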
diff --git a/tests/bofire/kernels/test_categorical.py b/tests/bofire/kernels/test_categorical.py
index 3b7ef5b4..20a81d9f 100644
--- a/tests/bofire/kernels/test_categorical.py
+++ b/tests/bofire/kernels/test_categorical.py
@@ -5,14 +5,76 @@
 from bofire.kernels.categorical import HammingKernelWithOneHots
 
 
-def test_hamming_with_one_hot():
+def test_hamming_with_one_hot_one_feature():
+    cat = {0: 3}
+
     k1 = CategoricalKernel()
-    k2 = HammingKernelWithOneHots()
+    k2 = HammingKernelWithOneHots(categorical_features=cat)
 
     xin_oh = torch.eye(3)
-    xin_cat = OneHotToNumeric(3, categorical_features={0: 3}).transform(xin_oh)
+    xin_cat = OneHotToNumeric(3, categorical_features=cat).transform(xin_oh)
 
     z1 = k1(xin_cat).to_dense()
     z2 = k2(xin_oh).to_dense()
 
+    assert z1.shape == z2.shape == (3, 3)
+    assert torch.allclose(z1, z2)
+
+
+def test_hamming_with_one_hot_two_features():
+    cat = {0: 2, 2: 4}
+
+    k1 = CategoricalKernel()
+    k2 = HammingKernelWithOneHots(categorical_features=cat)
+
+    xin_oh = torch.zeros(4, 6)
+    xin_oh[:2, :2] = xin_oh[2:, :2] = torch.eye(2)
+    xin_oh[:, 2:] = torch.eye(4)
+
+    xin_cat = OneHotToNumeric(6, categorical_features=cat).transform(xin_oh)
+
+    z1 = k1(xin_cat).to_dense()
+    z2 = k2(xin_oh).to_dense()
+
+    assert z1.shape == z2.shape == (4, 4)
+    assert torch.allclose(z1, z2)
+
+
+def test_hamming_with_one_hot_two_features_and_lengthscales():
+    cat = {0: 2, 2: 4}
+
+    k1 = CategoricalKernel(ard_num_dims=2)
+    k1.lengthscale = torch.tensor([1.5, 3.0])
+
+    # botorch checks that the ARD lengthscale has as many entries as the one-hot
+    # encoded inputs, so ard_num_dims has to be set to the one-hot width. The
+    # kernel itself only uses the leading entries, one per categorical feature.
+    k2 = HammingKernelWithOneHots(categorical_features=cat, ard_num_dims=6)
+    k2.lengthscale = torch.tensor([1.5, 3.0, 0.0, 0.0, 0.0, 0.0])
+
+    xin_oh = torch.zeros(4, 6)
+    xin_oh[:2, :2] = xin_oh[2:, :2] = torch.eye(2)
+    xin_oh[:, 2:] = torch.eye(4)
+
+    xin_cat = OneHotToNumeric(6, categorical_features=cat).transform(xin_oh)
+
+    z1 = k1(xin_cat).to_dense()
+    z2 = k2(xin_oh).to_dense()
+
+    assert z1.shape == z2.shape == (4, 4)
+    assert torch.allclose(z1, z2)
+
+
+def test_feature_order():
+    x1_in = torch.zeros(4, 2)
+    x1_in[:2, :] = x1_in[2:, :] = torch.eye(2)
+    x2_in = torch.eye(4)
+
+    k1 = HammingKernelWithOneHots(categorical_features={0: 2, 2: 4})
+    k2 = HammingKernelWithOneHots(categorical_features={0: 4, 4: 2})
+
+    z1 = k1(torch.cat([x1_in, x2_in], dim=1)).to_dense()
+    z2 = k2(torch.cat([x2_in, x1_in], dim=1)).to_dense()
+
+    assert z1.shape == z2.shape == (4, 4)
     assert torch.allclose(z1, z2)
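One way to convince yourself that the padded lengthscale entries are inert, as the comment in the ARD test above claims: build the same kernel with two different tail values and check that the Gram matrices match. A sketch under that assumption (the tail values are arbitrary placeholders):

    import torch

    from bofire.kernels.categorical import HammingKernelWithOneHots

    x = torch.zeros(4, 6)
    x[:2, :2] = x[2:, :2] = torch.eye(2)
    x[:, 2:] = torch.eye(4)

    def gram(tail: float) -> torch.Tensor:
        # ard_num_dims matches the one-hot width (6), but forward() slices the
        # lengthscale down to one entry per categorical feature (2 here).
        k = HammingKernelWithOneHots(categorical_features={0: 2, 2: 4}, ard_num_dims=6)
        k.lengthscale = torch.tensor([1.5, 3.0, tail, tail, tail, tail])
        return k(x).to_dense()

    # Changing the unused tail entries leaves the result untouched.
    assert torch.allclose(gram(7.0), gram(123.0))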
diff --git a/tests/bofire/kernels/test_mapper.py b/tests/bofire/kernels/test_mapper.py
index fec7d834..73bba933 100644
--- a/tests/bofire/kernels/test_mapper.py
+++ b/tests/bofire/kernels/test_mapper.py
@@ -21,6 +21,7 @@
     WassersteinKernel,
 )
 from bofire.data_models.priors.api import THREESIX_SCALE_PRIOR, GammaPrior
+from bofire.kernels.categorical import HammingKernelWithOneHots
 from tests.bofire.data_models.specs.api import Spec
 
 
@@ -250,3 +251,162 @@ def test_map_wasserstein_kernel():
     )
     assert k.squared is True
     assert hasattr(k, "lengthscale_prior") is False
+
+
+def test_map_HammingDistanceKernel_to_onehot_with_ard():
+    fmap = {
+        "x_cat_1": [5, 6, 7, 8],
+        "x_cat_2": [2, 3],
+    }
+
+    k_mapped = kernels.map(
+        HammingDistanceKernel(
+            ard=True,
+            features=["x_cat_1", "x_cat_2"],
+        ),
+        batch_shape=torch.Size(),
+        ard_num_dims=10,
+        active_dims=list(range(5)),
+        features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
+    )
+
+    assert isinstance(k_mapped, HammingKernelWithOneHots)
+    assert k_mapped.active_dims.tolist() == [5, 6, 7, 8, 2, 3]
+    assert k_mapped.ard_num_dims == 6
+    assert k_mapped.lengthscale.shape == (1, 6)
+    assert k_mapped.trx.categorical_features == {0: 4, 4: 2}
+
+
+def test_map_HammingDistanceKernel_to_onehot_without_ard():
+    fmap = {
+        "x_cat_1": [5, 6, 7, 8],
+        "x_cat_2": [2, 3],
+    }
+
+    k_mapped = kernels.map(
+        HammingDistanceKernel(
+            ard=False,
+            features=["x_cat_1", "x_cat_2"],
+        ),
+        batch_shape=torch.Size(),
+        ard_num_dims=10,
+        active_dims=list(range(5)),
+        features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
+    )
+
+    assert isinstance(k_mapped, HammingKernelWithOneHots)
+    assert k_mapped.active_dims.tolist() == [5, 6, 7, 8, 2, 3]
+    assert k_mapped.ard_num_dims is None
+    assert k_mapped.lengthscale.shape == (1, 1)
+    assert k_mapped.trx.categorical_features == {0: 4, 4: 2}
+
+
+def test_map_HammingDistanceKernel_to_categorical_without_ard():
+    k_mapped = kernels.map(
+        HammingDistanceKernel(
+            ard=False,
+        ),
+        batch_shape=torch.Size(),
+        ard_num_dims=10,
+        active_dims=list(range(5)),
+        features_to_idx_mapper=None,
+    )
+
+    assert isinstance(k_mapped, CategoricalKernel)
+    assert k_mapped.active_dims.tolist() == [0, 1, 2, 3, 4]
+    assert k_mapped.ard_num_dims is None
+    assert k_mapped.lengthscale.shape == (1, 1)
+
+
+def test_map_HammingDistanceKernel_to_categorical_with_ard():
+    k_mapped = kernels.map(
+        HammingDistanceKernel(
+            ard=True,
+        ),
+        batch_shape=torch.Size(),
+        ard_num_dims=10,
+        active_dims=list(range(5)),
+        features_to_idx_mapper=None,
+    )
+
+    assert isinstance(k_mapped, CategoricalKernel)
+    assert k_mapped.active_dims.tolist() == [0, 1, 2, 3, 4]
+    assert k_mapped.ard_num_dims == 5
+    assert k_mapped.lengthscale.shape == (1, 5)
+
+
+def test_map_HammingDistanceKernel_to_onehot_checks_dimension_overlap():
+    fmap = {
+        "x_cat_1": [3, 4],
+        "x_cat_2": [2, 3],
+    }
+
+    with pytest.raises(RuntimeError):
+        kernels.map(
+            HammingDistanceKernel(
+                ard=True,
+                features=["x_cat_1", "x_cat_2"],
+            ),
+            batch_shape=torch.Size(),
+            ard_num_dims=10,
+            active_dims=list(range(5)),
+            features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
+        )
+
+
+def test_map_HammingDistanceKernel_to_onehot_checks_onehot_encoding():
+    fmap = {
+        "x_cat_1": [4],
+        "x_cat_2": [2, 3],
+    }
+
+    with pytest.raises(RuntimeError):
+        kernels.map(
+            HammingDistanceKernel(
+                ard=True,
+                features=["x_cat_1", "x_cat_2"],
+            ),
+            batch_shape=torch.Size(),
+            ard_num_dims=10,
+            active_dims=list(range(5)),
+            features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
+        )
+
+
+def test_map_multiple_kernels_on_feature_subsets():
+    fmap = {
+        "x_1": [0],
+        "x_2": [1],
+        "x_cat_1": [2, 3],
+        "x_cat_2": [4, 5],
+    }
+
+    k_mapped = kernels.map(
+        AdditiveKernel(
+            kernels=[
+                HammingDistanceKernel(
+                    ard=True,
+                    features=["x_cat_1", "x_cat_2"],
+                ),
+                RBFKernel(
+                    features=["x_1", "x_2"],
+                ),
+            ]
+        ),
+        batch_shape=torch.Size(),
+        ard_num_dims=10,
+        active_dims=list(range(5)),
+        features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
+    )
+
+    assert len(k_mapped.kernels) == 2
+
+    assert isinstance(k_mapped.kernels[0], HammingKernelWithOneHots)
+    assert k_mapped.kernels[0].active_dims.tolist() == [2, 3, 4, 5]
+    assert k_mapped.kernels[0].ard_num_dims == 4
+
+    from gpytorch.kernels import RBFKernel as GpytorchRBFKernel
+
+    assert isinstance(k_mapped.kernels[1], GpytorchRBFKernel)
+    assert k_mapped.kernels[1].active_dims.tolist() == [0, 1]
+    assert k_mapped.kernels[1].ard_num_dims == 2