
Commit d0314bc

PR response changes pt. 1 (mainly ProductKernel assertion adjustments)
1 parent ebe03af commit d0314bc

File tree

5 files changed: +101 −105 lines

botorch/sampling/pathwise/features/maps.py

Lines changed: 48 additions & 6 deletions
@@ -118,18 +118,45 @@ def __init__(
         self.output_transform = output_transform
 
     def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
-        blocks = []
-        shape = self.raw_output_shape
-        ndim = len(shape)
+        # Collect the (possibly broadcasted) feature tensors in `blocks`.
+        # ``self.raw_output_shape`` encodes the *final* shape of the concatenated
+        # feature map. For each individual ``feature_map`` we therefore need to
+        # (1) obtain its dense representation, (2) broadcast it so that its
+        # trailing dimensions match ``self.raw_output_shape``, and (3) rescale it
+        # if we replicate the same tensor multiple times (to keep ‖ϕ‖ roughly
+        # invariant).
+
+        blocks: list[Tensor] = []
+
+        shape = self.raw_output_shape  # target output shape
+        ndim = len(shape)  # number of feature dimensions, incl. batch
+
         for feature_map in self:
+            # 1. Evaluate (dense) features for the current sub-map.
             block = feature_map(x, **kwargs).to_dense()
+
             block_ndim = len(feature_map.output_shape)
+
+            # 2. If this map has fewer *feature* dimensions than the direct sum
+            # (e.g. a vector-valued sub-map in a matrix-valued direct sum), we
+            # have to *tile* it along the missing leading feature dimensions so
+            # that shapes line up for concatenation.
             if block_ndim < ndim:
+                # ``tile_shape`` says how many copies are needed along every
+                # missing feature dimension (can be >1, e.g. for Kronecker sums).
                 tile_shape = shape[-ndim:-block_ndim]
+
+                # Rescale by 1/√k when the same block is replicated k times, to
+                # avoid artificially inflating its norm (motivated by the fact
+                # that direct sums of orthogonal features preserve inner products
+                # only up to such a scaling).
                 num_copies = prod(tile_shape)
                 if num_copies > 1:
-                    block = block * (num_copies**-0.5)
+                    block = block * (num_copies ** -0.5)
 
+                # ``multi_index`` inserts ``None`` entries (each adds a singleton
+                # axis, as in NumPy) so that broadcasting expands the tensor along
+                # the new axes without additional memory allocations.
                 multi_index = (
                     ...,
                     *repeat(None, ndim - block_ndim),
@@ -138,12 +165,17 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
                 block = block[multi_index].expand(
                     *block.shape[:-block_ndim], *tile_shape, *block.shape[-block_ndim:]
                 )
+
+            # 3. Append the (now correctly shaped) block to be concatenated later.
             blocks.append(block)
 
+        # Concatenate along the *last* axis (the feature dimension).
        return torch.concat(blocks, dim=-1)
 
     @property
     def raw_output_shape(self) -> Size:
+        # If the container is empty (e.g. DirectSumFeatureMap([])), treat the
+        # output as 0-D until feature maps are added.
         if not self:
             return Size([])
 
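To see what the tiling and rescaling comments describe, here is a minimal, self-contained sketch of the same indexing pattern on plain tensors (the names `block` and `target_shape` and the concrete shapes are illustrative, not part of the BoTorch API):

    from itertools import repeat
    from math import prod

    import torch

    # A direct sum that emits features of shape (4, 3), fed by a sub-map that
    # only produces 3 features per input: the block must be tiled 4 times.
    block = torch.randn(10, 3)           # 10 inputs, block_ndim = 1
    target_shape = torch.Size([4, 3])    # the direct sum's raw_output_shape
    ndim, block_ndim = len(target_shape), 1

    tile_shape = target_shape[-ndim:-block_ndim]  # Size([4])
    num_copies = prod(tile_shape)                 # 4
    if num_copies > 1:
        block = block * (num_copies ** -0.5)      # keep the block's norm roughly invariant

    # Indexing with None adds singleton axes; expand() broadcasts without copying.
    multi_index = (..., *repeat(None, ndim - block_ndim), *repeat(slice(None), block_ndim))
    tiled = block[multi_index].expand(
        *block.shape[:-block_ndim], *tile_shape, *block.shape[-block_ndim:]
    )
    assert tiled.shape == torch.Size([10, 4, 3])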
@@ -204,13 +236,23 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
             block = feature_map(x, **kwargs)
             block_ndim = len(feature_map.output_shape)
             if block_ndim == ndim:
+                # Case 1: this sub-map already has the maximum feature rank we
+                # are going to emit. Simply make sure it is stored sparsely:
+                #   - convert a `LinearOperator` to dense so that `.to_sparse()`
+                #     is available;
+                #   - if it is still dense, call `.to_sparse()`; otherwise keep
+                #     the sparse representation it already has.
                 block = block.to_dense() if isinstance(block, LinearOperator) else block
                 block = block if block.is_sparse else block.to_sparse()
             else:
+                # Case 2: lower-rank feature map. Bring it up to `ndim` by
+                # slicing with `None` (adds singleton axes) so broadcasting can
+                # expand it later. It stays dense here because the result is
+                # packed into a block-diagonal sparse matrix at the very end.
                 multi_index = (
                     ...,
-                    *repeat(None, ndim - block_ndim),
-                    *repeat(slice(None), block_ndim),
+                    *repeat(None, ndim - block_ndim),  # adds the missing dims
+                    *repeat(slice(None), block_ndim),  # keeps the existing dims
                 )
                 block = block.to_dense()[multi_index]
             blocks.append(block)
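The two branches above reduce to the following pattern on plain tensors (a sketch under the assumption that `ndim` is the maximum feature rank among the sub-maps; this is not the actual class implementation):

    from itertools import repeat

    import torch

    ndim = 2

    # Case 1: the block already has full feature rank -> store it sparsely.
    full = torch.eye(3)                                  # dense, block_ndim == ndim
    full = full if full.is_sparse else full.to_sparse()  # now a sparse COO tensor

    # Case 2: a lower-rank block is lifted to rank `ndim` with None slicing and
    # stays dense; it is only embedded into a sparse block-diagonal at the end.
    low = torch.arange(3.0)                              # block_ndim == 1
    block_ndim = low.ndim
    multi_index = (..., *repeat(None, ndim - block_ndim), *repeat(slice(None), block_ndim))
    lifted = low[multi_index]                            # shape (1, 3), broadcast-ready
    assert lifted.shape == (1, 3)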

botorch/sampling/pathwise/prior_samplers.py

Lines changed: 27 additions & 47 deletions
@@ -149,55 +149,35 @@ def _draw_kernel_feature_paths_MultiTaskGP(
         else model._task_feature
     )
 
-    # Extract kernels from the product kernel structure
-    # model.covar_module is a ProductKernel
-    # containing data_covar_module * task_covar_module
+    # MultiTaskGP *always* wraps data_covar_module and task_covar_module in a
+    # ProductKernel (see the MTGP implementation). If that invariant is violated,
+    # we raise an error rather than silently guessing how to proceed.
+
     from gpytorch.kernels import ProductKernel
 
-    if isinstance(model.covar_module, ProductKernel):
-        # Get the individual kernels from the product kernel
-        kernels = model.covar_module.kernels
-
-        # Find data and task kernels based on their active_dims
-        data_kernel = None
-        task_kernel = None
-
-        for kernel in kernels:
-            if hasattr(kernel, "active_dims") and kernel.active_dims is not None:
-                if task_index in kernel.active_dims:
-                    task_kernel = deepcopy(kernel)
-                else:
-                    data_kernel = deepcopy(kernel)
-            else:
-                # If no active_dims, it's likely the data kernel
-                data_kernel = deepcopy(kernel)
-                data_kernel.active_dims = torch.LongTensor(
-                    [
-                        index
-                        for index in range(train_X.shape[-1])
-                        if index != task_index
-                    ],
-                    device=data_kernel.device,
-                )
-
-        # If we couldn't find the task kernel, create it based on the structure
-        if task_kernel is None:
-            from gpytorch.kernels import IndexKernel
-
-            task_kernel = IndexKernel(
-                num_tasks=model.num_tasks,
-                rank=model._rank,
-                active_dims=[task_index],
-            ).to(device=model.covar_module.device, dtype=model.covar_module.dtype)
-
-        # Set task kernel active dims correctly
-        task_kernel.active_dims = torch.tensor([task_index], device=task_kernel.device)
-
-        # Use the existing product kernel structure
-        combined_kernel = data_kernel * task_kernel
-    else:
-        # Fallback to using the original covar_module directly
-        combined_kernel = model.covar_module
+    if not isinstance(model.covar_module, ProductKernel):
+        raise RuntimeError(
+            "Expected `model.covar_module` to be a ProductKernel (data × task), "
+            f"but found {type(model.covar_module).__name__}. If you are wrapping "
+            "kernels manually, please combine them with gpytorch.kernels.ProductKernel "
+            "so the pathwise utilities can reason about the structure."
+        )
+
+    # The product already represents data_kernel * task_kernel; we can pass it
+    # straight through to downstream routines.
+    combined_kernel = model.covar_module
+
+    # Ensure the data kernel inside the product has `active_dims` set; this is
+    # required downstream by `get_kernel_num_inputs`. MTGPs created via the
+    # public constructor already do this, but if a user manually overwrote the
+    # `covar_module`, we may need to patch it up here.
+    kernels = combined_kernel.kernels  # type: ignore[attr-defined]
+    for k in kernels:
+        if getattr(k, "active_dims", None) is None:
+            k.active_dims = torch.tensor(
+                [idx for idx in range(num_ambient_inputs) if idx != task_index],
+                dtype=torch.long, device=k.device,
+            )
 
     return _draw_kernel_feature_paths_fallback(
         mean_module=model.mean_module,
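For illustration, the enforced invariant and the `active_dims` patch-up look roughly like this on a hand-built kernel (a standalone sketch; a real MultiTaskGP constructs the ProductKernel itself):

    import torch
    from gpytorch.kernels import IndexKernel, ProductKernel, RBFKernel

    num_inputs, task_index = 3, 2  # the last column holds the task id

    # The structure the helper expects: data kernel x task kernel.
    data_kernel = RBFKernel()  # deliberately has no active_dims yet
    task_kernel = IndexKernel(num_tasks=2, rank=1, active_dims=[task_index])
    covar_module = ProductKernel(data_kernel, task_kernel)

    assert isinstance(covar_module, ProductKernel)  # otherwise: RuntimeError

    # Patch any sub-kernel missing active_dims, as the new code does.
    for k in covar_module.kernels:
        if getattr(k, "active_dims", None) is None:
            k.active_dims = torch.tensor(
                [i for i in range(num_inputs) if i != task_index], dtype=torch.long
            )

    assert data_kernel.active_dims.tolist() == [0, 1]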

botorch/sampling/pathwise/update_strategies.py

Lines changed: 20 additions & 48 deletions
@@ -164,62 +164,34 @@ def _draw_kernel_feature_paths_MultiTaskGP(
     if noise_covariance is None:
         noise_covariance = likelihood.noise_covar(shape=points.shape[:-1])
 
-    # Prepare product kernel
+    # Determine the total input dimensionality and identify the task-feature index.
     num_inputs = points.shape[-1]
-    # TODO: Changed `MultiTaskGP` to normalize the task feature in its constructor.
     task_index = (
-        num_inputs + model._task_feature
-        if model._task_feature < 0
-        else model._task_feature
+        num_inputs + model._task_feature if model._task_feature < 0 else model._task_feature
     )
 
-    # Extract kernels from the product kernel structure
-    # model.covar_module is a ProductKernel
-    # containing data_covar_module * task_covar_module
+    # MTGP should always provide a ProductKernel = data × task. Enforce that
+    # contract and surface actionable feedback if it is violated.
+
     from gpytorch.kernels import ProductKernel
 
-    if isinstance(model.covar_module, ProductKernel):
-        # Get the individual kernels from the product kernel
-        kernels = model.covar_module.kernels
-
-        # Find data and task kernels based on their active_dims
-        data_kernel = None
-        task_kernel = None
-
-        for kernel in kernels:
-            if hasattr(kernel, "active_dims") and kernel.active_dims is not None:
-                if task_index in kernel.active_dims:
-                    task_kernel = deepcopy(kernel)
-                else:
-                    data_kernel = deepcopy(kernel)
-            else:
-                # If no active_dims, it's likely the data kernel
-                data_kernel = deepcopy(kernel)
-                data_kernel.active_dims = torch.LongTensor(
-                    [index for index in range(num_inputs) if index != task_index],
-                    device=data_kernel.device,
-                )
-
-        # If we couldn't find the task kernel, create it based on the structure
-        if task_kernel is None:
-            from gpytorch.kernels import IndexKernel
-
-            task_kernel = IndexKernel(
-                num_tasks=model.num_tasks,
-                rank=model._rank,
-                active_dims=[task_index],
-            ).to(device=model.covar_module.device, dtype=model.covar_module.dtype)
-
-        # Set task kernel active dims correctly
-        task_kernel.active_dims = torch.LongTensor(
-            [task_index], device=task_kernel.device
+    if not isinstance(model.covar_module, ProductKernel):
+        raise RuntimeError(
+            "MultiTaskGP `covar_module` is expected to be a ProductKernel (data × task), "
+            f"but found {type(model.covar_module).__name__}. If you build a custom "
+            "MTGP variant, please wrap the two kernels with gpytorch.kernels.ProductKernel."
         )
 
-        # Use the existing product kernel structure
-        combined_kernel = data_kernel * task_kernel
-    else:
-        # Fallback to using the original covar_module directly
-        combined_kernel = model.covar_module
+    combined_kernel = model.covar_module
+
+    # Ensure the data part of the product kernel has `active_dims` set; it is
+    # required by downstream helpers when calculating input dimensionality.
+    kernels = combined_kernel.kernels  # type: ignore[attr-defined]
+    for k in kernels:
+        if getattr(k, "active_dims", None) is None:
+            k.active_dims = torch.tensor(
+                [idx for idx in range(num_inputs) if idx != task_index], dtype=torch.long, device=k.device
+            )
 
     # Return exact update using product kernel
     return _gaussian_update_exact(
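The task-index normalization at the top of the hunk is ordinary negative-index arithmetic; a worked example:

    num_inputs = 4       # columns: x0, x1, x2, task
    task_feature = -1    # "last column", as callers often pass it
    task_index = num_inputs + task_feature if task_feature < 0 else task_feature
    assert task_index == 3  # normalized to a non-negative column index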

test/sampling/pathwise/test_prior_samplers.py

Lines changed: 3 additions & 2 deletions
@@ -234,8 +234,9 @@ def test_multitask_gp_kernel_handling(self):
         simple_kernel = RBFKernel(ard_num_dims=3)
         model2.covar_module = simple_kernel  # Non-ProductKernel
 
-        paths2 = draw_kernel_feature_paths(model2, sample_shape=Size([1]))
-        self.assertIsNotNone(paths2)
+        # A non-ProductKernel covar_module should now raise instead of guessing.
+        with self.assertRaises(RuntimeError):
+            draw_kernel_feature_paths(model2, sample_shape=Size([1]))
 
         # Test kernel without active_dims to trigger active_dims assignment
         model3 = MultiTaskGP(train_X=train_X, train_Y=train_Y, task_feature=2)

test/sampling/pathwise/test_update_strategies.py

Lines changed: 3 additions & 2 deletions
@@ -311,8 +311,9 @@ def test_multitask_gp_kernel_handling(self):
         simple_kernel = RBFKernel(ard_num_dims=3)
         model2.covar_module = simple_kernel  # Non-ProductKernel
 
-        update_paths2 = gaussian_update(model=model2, sample_values=sample_values)
-        self.assertIsNotNone(update_paths2)
+        # A non-ProductKernel covar_module should now raise instead of guessing.
+        with self.assertRaises(RuntimeError):
+            gaussian_update(model=model2, sample_values=sample_values)
 
         # Test kernel without active_dims to trigger active_dims assignment
         model3 = MultiTaskGP(train_X=train_X, train_Y=train_Y, task_feature=2)
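The contract both updated tests pin down can be distilled into a few standalone lines (a sketch; `ensure_product_kernel` is an illustrative helper name, not a BoTorch function):

    from gpytorch.kernels import Kernel, ProductKernel, RBFKernel

    def ensure_product_kernel(covar_module: Kernel) -> ProductKernel:
        # Mirrors the invariant check added to the pathwise MTGP helpers.
        if not isinstance(covar_module, ProductKernel):
            raise RuntimeError(
                "Expected a ProductKernel (data x task), "
                f"got {type(covar_module).__name__}."
            )
        return covar_module

    try:
        ensure_product_kernel(RBFKernel(ard_num_dims=3))  # the tests' failure case
    except RuntimeError as err:
        print(err)  # Expected a ProductKernel (data x task), got RBFKernel.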
