categorical onehot kernel uses the right lengthscale for multiple features
e-dorigatti committed Jan 13, 2025
1 parent 561ac20 commit 1867e7b
Showing 4 changed files with 295 additions and 19 deletions.
49 changes: 45 additions & 4 deletions bofire/kernels/categorical.py
@@ -1,24 +1,65 @@
from typing import Dict

import torch
from botorch.models.transforms.input import OneHotToNumeric
from gpytorch.kernels.kernel import Kernel
from torch import Tensor


class HammingKernelWithOneHots(Kernel):
r"""
    A kernel for one-hot encoded categorical features. The inputs
    may contain more than one categorical feature.

    Computes `exp(-dist(x1, x2) / lengthscale)`, where `dist(x1, x2)`
    is zero if `x1` and `x2` correspond to the same category and one
    otherwise. If the last dimension is not a batch dimension, the
    distance is averaged over the categorical features.
Note: This kernel is NOT differentiable w.r.t. the inputs.
"""

has_lengthscale = True

def __init__(self, categorical_features: Dict[int, int], *args, **kwargs):
"""
Initialize.
Args:
categorical_features: A dictionary mapping the starting index of each
categorical feature to its cardinality. This assumes that categoricals
are one-hot encoded.
*args, **kwargs: Passed to gpytorch.kernels.kernel.Kernel
"""
super().__init__(*args, **kwargs)

onehot_dim = sum(categorical_features.values())
self.trx = OneHotToNumeric(
onehot_dim, categorical_features=categorical_features
)

def forward(
self,
x1: Tensor,
x2: Tensor,
diag: bool = False,
last_dim_is_batch: bool = False,
**params,
) -> Tensor:
-        delta = (x1.unsqueeze(-2) - x2.unsqueeze(-3)) ** 2
-        dists = delta / self.lengthscale.unsqueeze(-2)
+        x1 = self.trx(x1)
+        x2 = self.trx(x2)
+
+        delta = x1.unsqueeze(-2) != x2.unsqueeze(-3)
+        if self.ard_num_dims is not None:
+            ls = self.lengthscale[..., : delta.shape[-1]]
+        else:
+            ls = self.lengthscale
+
+        dists = delta / ls.unsqueeze(-2)
        if last_dim_is_batch:
            dists = dists.transpose(-3, -1)
-
-        dists = dists.sum(-1) / 2
+        else:
+            dists = dists.mean(-1)
        res = torch.exp(-dists)
        if diag:
            res = torch.diagonal(res, dim1=-1, dim2=-2)
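For reference, a minimal usage sketch of the kernel above; the two-feature layout, the lengthscale value, and the inputs are illustrative and not part of the commit:

import torch

from bofire.kernels.categorical import HammingKernelWithOneHots

# Two one-hot encoded categoricals: two levels starting at column 0,
# three levels starting at column 2 (five one-hot columns in total).
kernel = HammingKernelWithOneHots(categorical_features={0: 2, 2: 3})
kernel.lengthscale = 1.0

x = torch.tensor(
    [
        [1.0, 0.0, 1.0, 0.0, 0.0],  # categories (A, X)
        [1.0, 0.0, 0.0, 1.0, 0.0],  # categories (A, Y): one of two features differs
    ]
)

# With a unit lengthscale the off-diagonal entries are exp(-1/2) ~= 0.61:
# the Hamming distance is averaged over the two categorical features.
print(kernel(x).to_dense())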
37 changes: 25 additions & 12 deletions bofire/kernels/mapper.py
@@ -218,22 +218,35 @@ def map_HammingDistanceKernel(
active_dims: List[int],
features_to_idx_mapper: Optional[Callable[[List[str]], List[int]]],
) -> GpytorchKernel:
-    active_dims = _compute_active_dims(data_model, active_dims, features_to_idx_mapper)
+    if data_model.features is not None:
+        if features_to_idx_mapper is None:
+            raise RuntimeError(
+                "features_to_idx_mapper must be defined when using only a subset of features"
+            )

-    with_one_hots = data_model.features is not None and len(active_dims) > 1
-    if with_one_hots and len(active_dims) == 1:
-        raise RuntimeError(
-            "only one feature for categorical kernel operating on one-hot features"
-        )
-    elif not with_one_hots and len(active_dims) > 1:
-        # this is not necessarily an issue since botorch's CategoricalKernel
-        # can work on multiple features at the same time
-        pass
+        active_dims = []
+        categorical_features = {}
+        for k in data_model.features:
+            idx = features_to_idx_mapper([k])
+            categorical_features[len(active_dims)] = len(idx)
+
+            already_used = [i for i in idx if i in active_dims]
+            if already_used:
+                raise RuntimeError(
+                    f"indices {already_used} are used in more than one categorical feature"
+                )
+
+            active_dims.extend(idx)
+
+            if len(idx) == 1:
+                raise RuntimeError(
+                    f"feature {k} is supposed to be one-hot encoded but is mapped to a single dimension"
+                )

-    if with_one_hots:
        return HammingKernelWithOneHots(
-            batch_shape=batch_shape,
+            categorical_features=categorical_features,
            ard_num_dims=len(active_dims) if data_model.ard else None,
+            batch_shape=batch_shape,
            active_dims=active_dims,  # type: ignore
            lengthscale_constraint=GreaterThan(1e-06),
        )
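The bookkeeping above can be read in isolation: for each categorical feature, the mapper records the offset of its one-hot block relative to the kernel's own (re-ordered) inputs, not relative to the full model inputs. A standalone sketch with a hypothetical helper name (error checks omitted):

from typing import Callable, Dict, List, Tuple


def onehot_bookkeeping(
    features: List[str],
    features_to_idx_mapper: Callable[[List[str]], List[int]],
) -> Tuple[List[int], Dict[int, int]]:
    """Return the remapped active dims and the {offset: cardinality} dict
    that HammingKernelWithOneHots expects."""
    active_dims: List[int] = []
    categorical_features: Dict[int, int] = {}
    for feature in features:
        idx = features_to_idx_mapper([feature])
        # the offset is len(active_dims), i.e. where this one-hot block
        # lands within the kernel's inputs after remapping
        categorical_features[len(active_dims)] = len(idx)
        active_dims.extend(idx)
    return active_dims, categorical_features


fmap = {"x_cat_1": [5, 6, 7, 8], "x_cat_2": [2, 3]}
dims, cats = onehot_bookkeeping(
    ["x_cat_1", "x_cat_2"], lambda ks: [i for k in ks for i in fmap[k]]
)
assert dims == [5, 6, 7, 8, 2, 3]
assert cats == {0: 4, 4: 2}  # as asserted in the mapper tests below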
68 changes: 65 additions & 3 deletions tests/bofire/kernels/test_categorical.py
@@ -5,14 +5,76 @@
from bofire.kernels.categorical import HammingKernelWithOneHots


-def test_hamming_with_one_hot():
+def test_hamming_with_one_hot_one_feature():
+    cat = {0: 3}
+
    k1 = CategoricalKernel()
-    k2 = HammingKernelWithOneHots()
+    k2 = HammingKernelWithOneHots(categorical_features=cat)

    xin_oh = torch.eye(3)
-    xin_cat = OneHotToNumeric(3, categorical_features={0: 3}).transform(xin_oh)
+    xin_cat = OneHotToNumeric(3, categorical_features=cat).transform(xin_oh)

z1 = k1(xin_cat).to_dense()
z2 = k2(xin_oh).to_dense()

assert z1.shape == z2.shape == (3, 3)
assert torch.allclose(z1, z2)


def test_hamming_with_one_hot_two_features():
cat = {0: 2, 2: 4}

k1 = CategoricalKernel()
k2 = HammingKernelWithOneHots(categorical_features=cat)

xin_oh = torch.zeros(4, 6)
xin_oh[:2, :2] = xin_oh[2:, :2] = torch.eye(2)
xin_oh[:, 2:] = torch.eye(4)

xin_cat = OneHotToNumeric(6, categorical_features=cat).transform(xin_oh)

z1 = k1(xin_cat).to_dense()
z2 = k2(xin_oh).to_dense()

assert z1.shape == z2.shape == (4, 4)
assert torch.allclose(z1, z2)


def test_hamming_with_one_hot_two_features_and_lengthscales():
cat = {0: 2, 2: 4}

k1 = CategoricalKernel(ard_num_dims=2)
k1.lengthscale = torch.tensor([1.5, 3.0])

    # botorch will check that an ARD lengthscale has the same number of elements
    # as the one-hot encoded inputs, so we have to set ard_num_dims accordingly.
    # The kernel then uses only the leading entries, one per categorical feature.
k2 = HammingKernelWithOneHots(categorical_features=cat, ard_num_dims=6)
k2.lengthscale = torch.tensor([1.5, 3.0, 0.0, 0.0, 0.0, 0.0])

xin_oh = torch.zeros(4, 6)
xin_oh[:2, :2] = xin_oh[2:, :2] = torch.eye(2)
xin_oh[:, 2:] = torch.eye(4)

xin_cat = OneHotToNumeric(6, categorical_features=cat).transform(xin_oh)

z1 = k1(xin_cat).to_dense()
z2 = k2(xin_oh).to_dense()

assert z1.shape == z2.shape == (4, 4)
assert torch.allclose(z1, z2)
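The zero padding works because forward slices the lengthscale to the number of categorical features (ls = self.lengthscale[..., : delta.shape[-1]]), so the trailing entries never enter the computation. A quick standalone check of that behavior; the padding values are illustrative:

import torch

from bofire.kernels.categorical import HammingKernelWithOneHots

cat = {0: 2, 2: 4}
k_a = HammingKernelWithOneHots(categorical_features=cat, ard_num_dims=6)
k_b = HammingKernelWithOneHots(categorical_features=cat, ard_num_dims=6)

# identical leading entries (one per categorical feature), different padding
k_a.lengthscale = torch.tensor([1.5, 3.0, 9.0, 9.0, 9.0, 9.0])
k_b.lengthscale = torch.tensor([1.5, 3.0, 0.1, 0.1, 0.1, 0.1])

x = torch.zeros(4, 6)
x[:2, :2] = x[2:, :2] = torch.eye(2)
x[:, 2:] = torch.eye(4)

# the padding is ignored, so both kernels produce the same covariance matrix
assert torch.allclose(k_a(x).to_dense(), k_b(x).to_dense())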


def test_feature_order():
x1_in = torch.zeros(4, 2)
x1_in[:2, :] = x1_in[2:, :] = torch.eye(2)
x2_in = torch.eye(4)

k1 = HammingKernelWithOneHots(categorical_features={0: 2, 2: 4})
k2 = HammingKernelWithOneHots(categorical_features={0: 4, 4: 2})

z1 = k1(torch.cat([x1_in, x2_in], dim=1)).to_dense()
z2 = k2(torch.cat([x2_in, x1_in], dim=1)).to_dense()

assert z1.shape == z2.shape == (4, 4)
assert torch.allclose(z1, z2)
160 changes: 160 additions & 0 deletions tests/bofire/kernels/test_mapper.py
@@ -21,6 +21,7 @@
WassersteinKernel,
)
from bofire.data_models.priors.api import THREESIX_SCALE_PRIOR, GammaPrior
from bofire.kernels.categorical import HammingKernelWithOneHots
from tests.bofire.data_models.specs.api import Spec


@@ -250,3 +251,162 @@ def test_map_wasserstein_kernel():
)
assert k.squared is True
assert hasattr(k, "lengthscale_prior") is False


def test_map_HammingDistanceKernel_to_onehot_with_ard():
fmap = {
"x_cat_1": [5, 6, 7, 8],
"x_cat_2": [2, 3],
}

k_mapped = kernels.map(
HammingDistanceKernel(
ard=True,
features=["x_cat_1", "x_cat_2"],
),
batch_shape=torch.Size(),
ard_num_dims=10,
active_dims=list(range(5)),
features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
)

assert isinstance(k_mapped, HammingKernelWithOneHots)
assert k_mapped.active_dims.tolist() == [5, 6, 7, 8, 2, 3]
assert k_mapped.ard_num_dims == 6
assert k_mapped.lengthscale.shape == (1, 6)
assert k_mapped.trx.categorical_features == {0: 4, 4: 2}


def test_map_HammingDistanceKernel_to_onehot_without_ard():
fmap = {
"x_cat_1": [5, 6, 7, 8],
"x_cat_2": [2, 3],
}

k_mapped = kernels.map(
HammingDistanceKernel(
ard=False,
features=["x_cat_1", "x_cat_2"],
),
batch_shape=torch.Size(),
ard_num_dims=10,
active_dims=list(range(5)),
features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
)

assert isinstance(k_mapped, HammingKernelWithOneHots)
assert k_mapped.active_dims.tolist() == [5, 6, 7, 8, 2, 3]
assert k_mapped.ard_num_dims is None
assert k_mapped.lengthscale.shape == (1, 1)
assert k_mapped.trx.categorical_features == {0: 4, 4: 2}


def test_map_HammingDistanceKernel_to_categorical_without_ard():
k_mapped = kernels.map(
HammingDistanceKernel(
ard=False,
),
batch_shape=torch.Size(),
ard_num_dims=10,
active_dims=list(range(5)),
features_to_idx_mapper=None,
)

assert isinstance(k_mapped, CategoricalKernel)
assert k_mapped.active_dims.tolist() == [0, 1, 2, 3, 4]
assert k_mapped.ard_num_dims is None
assert k_mapped.lengthscale.shape == (1, 1)


def test_map_HammingDistanceKernel_to_categorical_with_ard():
k_mapped = kernels.map(
HammingDistanceKernel(
ard=True,
),
batch_shape=torch.Size(),
ard_num_dims=10,
active_dims=list(range(5)),
features_to_idx_mapper=None,
)

assert isinstance(k_mapped, CategoricalKernel)
assert k_mapped.active_dims.tolist() == [0, 1, 2, 3, 4]
assert k_mapped.ard_num_dims == 5
assert k_mapped.lengthscale.shape == (1, 5)


def test_map_HammingDistanceKernel_to_onehot_checks_dimension_overlap():
fmap = {
"x_cat_1": [3, 4],
"x_cat_2": [2, 3],
}

with pytest.raises(RuntimeError):
kernels.map(
HammingDistanceKernel(
ard=True,
features=["x_cat_1", "x_cat_2"],
),
batch_shape=torch.Size(),
ard_num_dims=10,
active_dims=list(range(5)),
features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
)


def test_map_HammingDistanceKernel_to_onehot_checks_onehot_encoding():
fmap = {
"x_cat_1": [4],
"x_cat_2": [2, 3],
}

with pytest.raises(RuntimeError):
kernels.map(
HammingDistanceKernel(
ard=True,
features=["x_cat_1", "x_cat_2"],
),
batch_shape=torch.Size(),
ard_num_dims=10,
active_dims=list(range(5)),
features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
)


def test_map_multiple_kernels_on_feature_subsets():
fmap = {
"x_1": [0],
"x_2": [1],
"x_cat_1": [2, 3],
"x_cat_2": [4, 5],
}

k_mapped = kernels.map(
AdditiveKernel(
kernels=[
HammingDistanceKernel(
ard=True,
features=["x_cat_1", "x_cat_2"],
),
RBFKernel(
features=["x_1", "x_2"],
),
]
),
batch_shape=torch.Size(),
ard_num_dims=10,
active_dims=list(range(5)),
features_to_idx_mapper=lambda ks: [i for k in ks for i in fmap[k]],
)

assert len(k_mapped.kernels) == 2

assert isinstance(k_mapped.kernels[0], HammingKernelWithOneHots)
assert k_mapped.kernels[0].active_dims.tolist() == [2, 3, 4, 5]
assert k_mapped.kernels[0].ard_num_dims == 4

from gpytorch.kernels import RBFKernel as GpytorchRBFKernel

assert isinstance(k_mapped.kernels[1], GpytorchRBFKernel)
assert k_mapped.kernels[1].active_dims.tolist() == [0, 1]
assert k_mapped.kernels[1].ard_num_dims == 2
