
Commit bf3a70e

Clean redo of ProductKernel MTGP adjustments
1 parent ebe03af commit bf3a70e

3 files changed: +64 additions, -23 deletions


botorch/sampling/pathwise/features/maps.py

Lines changed: 26 additions & 4 deletions
@@ -122,28 +122,42 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
         shape = self.raw_output_shape
         ndim = len(shape)
         for feature_map in self:
+            # Collect/scale individual feature blocks
             block = feature_map(x, **kwargs).to_dense()
             block_ndim = len(feature_map.output_shape)
+
+            # Handle broadcasting for lower-dimensional feature maps
             if block_ndim < ndim:
+                # Determine how the tiling/broadcasting works for lower-dimensional feature maps
                 tile_shape = shape[-ndim:-block_ndim]
                 num_copies = prod(tile_shape)
+
+                # Scale down by sqrt of number of copies to maintain proper variance
                 if num_copies > 1:
                     block = block * (num_copies**-0.5)

+                # Create multi-index for broadcasting: add None dimensions for tiling
+                # This expands the block to match the target dimensionality
                 multi_index = (
                     ...,
-                    *repeat(None, ndim - block_ndim),
-                    *repeat(slice(None), block_ndim),
+                    *repeat(None, ndim - block_ndim),  # Add new axes for tiling
+                    *repeat(slice(None), block_ndim),  # Keep existing dimensions
                 )
+                # Apply the multi-index and expand to tile across the new dimensions
                 block = block[multi_index].expand(
                     *block.shape[:-block_ndim], *tile_shape, *block.shape[-block_ndim:]
                 )
             blocks.append(block)

+        # Concatenate all blocks along the last dimension
         return torch.concat(blocks, dim=-1)

     @property
     def raw_output_shape(self) -> Size:
+        # Handle empty DirectSumFeatureMap case - can occur when:
+        # 1. Purposely start with an empty container and plan to append feature maps later, or
+        # 2. Deleted the last entry and the list is now length-zero.
+        # Returning Size([]) keeps the object in a queryable state until real feature maps are added.
         if not self:
             return Size([])
@@ -203,17 +217,25 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
         for feature_map in self:
             block = feature_map(x, **kwargs)
             block_ndim = len(feature_map.output_shape)
+
+            # Handle blocks that match the target dimensionality
             if block_ndim == ndim:
+                # Convert LinearOperator to dense tensor if needed
                 block = block.to_dense() if isinstance(block, LinearOperator) else block
+                # Ensure block is in sparse format for efficient block diagonal construction
                 block = block if block.is_sparse else block.to_sparse()
             else:
+                # For lower-dimensional blocks, we need to expand dimensions
+                # but keep them dense since sparse tensor broadcasting is limited
                 multi_index = (
                     ...,
-                    *repeat(None, ndim - block_ndim),
-                    *repeat(slice(None), block_ndim),
+                    *repeat(None, ndim - block_ndim),  # Add new axes for expansion
+                    *repeat(slice(None), block_ndim),  # Keep existing dimensions
                 )
                 block = block.to_dense()[multi_index]
             blocks.append(block)
+
+        # Construct sparse block diagonal matrix from all blocks
         return sparse_block_diag(blocks, base_ndim=ndim)

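For intuition on what this second forward pass assembles: sparse_block_diag places each feature map's block on the diagonal of one larger matrix. A dense stand-in using torch.block_diag (illustrative only; the real helper keeps the result sparse and batched):

import torch

# Two illustrative feature blocks for the same 5 inputs, with different widths.
block_a = torch.randn(5, 3)   # 3 features from the first map
block_b = torch.randn(5, 2)   # 2 features from the second map

# Same layout the sparse helper produces: blocks along the diagonal, zeros elsewhere.
dense = torch.block_diag(block_a, block_b)
print(dense.shape)  # torch.Size([10, 5])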

botorch/sampling/pathwise/prior_samplers.py

Lines changed: 19 additions & 9 deletions
@@ -150,11 +150,20 @@ def _draw_kernel_feature_paths_MultiTaskGP(
     )

     # Extract kernels from the product kernel structure
-    # model.covar_module is a ProductKernel
+    # model.covar_module is a ProductKernel by definition for MTGPs
     # containing data_covar_module * task_covar_module
     from gpytorch.kernels import ProductKernel

-    if isinstance(model.covar_module, ProductKernel):
+    if not isinstance(model.covar_module, ProductKernel):
+        # Fallback for non-ProductKernel cases (legacy support)
+        import warnings
+        warnings.warn(
+            f"MultiTaskGP with non-ProductKernel detected ({type(model.covar_module)}). "
+            "Consider using ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
+            UserWarning,
+        )
+        combined_kernel = model.covar_module
+    else:
         # Get the individual kernels from the product kernel
         kernels = model.covar_module.kernels

@@ -169,7 +178,7 @@ def _draw_kernel_feature_paths_MultiTaskGP(
                 else:
                     data_kernel = deepcopy(kernel)
             else:
-                # If no active_dims, it's likely the data kernel
+                # If no active_dims on data kernel, add them so downstream helpers don't error
                 data_kernel = deepcopy(kernel)
                 data_kernel.active_dims = torch.LongTensor(
                     [
@@ -180,7 +189,7 @@ def _draw_kernel_feature_paths_MultiTaskGP(
                     device=data_kernel.device,
                 )

-        # If we couldn't find the task kernel, create it based on the structure
+        # If the task kernel can't be found, create it based on the structure
         if task_kernel is None:
             from gpytorch.kernels import IndexKernel


@@ -190,14 +199,15 @@ def _draw_kernel_feature_paths_MultiTaskGP(
190199
active_dims=[task_index],
191200
).to(device=model.covar_module.device, dtype=model.covar_module.dtype)
192201

193-
# Set task kernel active dims correctly
194-
task_kernel.active_dims = torch.tensor([task_index], device=task_kernel.device)
202+
# Ensure the data kernel was found
203+
if data_kernel is None:
204+
raise ValueError(
205+
f"Could not identify data kernel from ProductKernel. "
206+
"MTGPs should follow the standard ProductKernel(IndexKernel, SomeOtherKernel) pattern."
207+
)
195208

196209
# Use the existing product kernel structure
197210
combined_kernel = data_kernel * task_kernel
198-
else:
199-
# Fallback to using the original covar_module directly
200-
combined_kernel = model.covar_module
201211

202212
return _draw_kernel_feature_paths_fallback(
203213
mean_module=model.mean_module,
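The branch above assumes the standard MTGP covariance layout. A minimal sketch of that structure (the kernel choice and dimensions are illustrative, not taken from the commit):

import torch
from gpytorch.kernels import IndexKernel, MaternKernel, ProductKernel

task_index, num_inputs, num_tasks = 3, 4, 2  # 3 data columns plus one task column

data_kernel = MaternKernel(
    nu=2.5,
    active_dims=torch.tensor([i for i in range(num_inputs) if i != task_index]),
)
task_kernel = IndexKernel(num_tasks=num_tasks, active_dims=[task_index])
covar_module = ProductKernel(data_kernel, task_kernel)

# The helper walks covar_module.kernels, picks out the IndexKernel as the task
# kernel, treats the remaining factor as the data kernel, and recombines them.
print([type(k).__name__ for k in covar_module.kernels])  # ['MaternKernel', 'IndexKernel']
combined_kernel = data_kernel * task_kernel              # the structure the diff rebuilds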

botorch/sampling/pathwise/update_strategies.py

Lines changed: 19 additions & 10 deletions
@@ -174,11 +174,21 @@ def _draw_kernel_feature_paths_MultiTaskGP(
     )

     # Extract kernels from the product kernel structure
-    # model.covar_module is a ProductKernel
+    # model.covar_module is a ProductKernel by definition for MTGPs
     # containing data_covar_module * task_covar_module
     from gpytorch.kernels import ProductKernel

-    if isinstance(model.covar_module, ProductKernel):
+    if not isinstance(model.covar_module, ProductKernel):
+        # Fallback for non-ProductKernel cases (legacy support)
+        # This should be rare as MTGPs typically use ProductKernels by definition
+        import warnings
+        warnings.warn(
+            f"MultiTaskGP with non-ProductKernel detected ({type(model.covar_module)}). "
+            "Consider using ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
+            UserWarning,
+        )
+        combined_kernel = model.covar_module
+    else:
         # Get the individual kernels from the product kernel
         kernels = model.covar_module.kernels

@@ -193,7 +203,7 @@ def _draw_kernel_feature_paths_MultiTaskGP(
                 else:
                     data_kernel = deepcopy(kernel)
             else:
-                # If no active_dims, it's likely the data kernel
+                # If no active_dims on data kernel, add them so downstream helpers don't error
                 data_kernel = deepcopy(kernel)
                 data_kernel.active_dims = torch.LongTensor(
                     [index for index in range(num_inputs) if index != task_index],
@@ -210,16 +220,15 @@ def _draw_kernel_feature_paths_MultiTaskGP(
                 active_dims=[task_index],
             ).to(device=model.covar_module.device, dtype=model.covar_module.dtype)

-        # Set task kernel active dims correctly
-        task_kernel.active_dims = torch.LongTensor(
-            [task_index], device=task_kernel.device
-        )
+        # Ensure data kernel was found
+        if data_kernel is None:
+            raise ValueError(
+                f"Could not identify data kernel from ProductKernel. "
+                "MTGPs should follow the standard ProductKernel(IndexKernel, SomeOtherKernel) pattern."
+            )

         # Use the existing product kernel structure
         combined_kernel = data_kernel * task_kernel
-    else:
-        # Fallback to using the original covar_module directly
-        combined_kernel = model.covar_module

     # Return exact update using product kernel
     return _gaussian_update_exact(
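Since both modules now emit the same UserWarning on the non-ProductKernel fallback path, a small helper can capture it in a test. This is only a sketch of the capture pattern; the helper name and the commented-out usage are hypothetical and not part of the commit:

import warnings

def collect_user_warnings(fn, *args, **kwargs):
    # Run fn and return (result, list of UserWarning messages it raised).
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = fn(*args, **kwargs)
    messages = [str(w.message) for w in caught if issubclass(w.category, UserWarning)]
    return result, messages

# Usage idea (names are placeholders): call the relevant pathwise helper with a
# MultiTaskGP whose covar_module is not a ProductKernel, then assert on the message:
#   _, msgs = collect_user_warnings(pathwise_helper, model)
#   assert any("non-ProductKernel" in m for m in msgs)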
