Skip to content

Commit 29640a9

Browse files
committed
clean redo of ProductKernel MTGP adjusments
1 parent bf3a70e commit 29640a9

File tree

5 files changed

+105
-80
lines changed

5 files changed

+105
-80
lines changed

botorch/sampling/pathwise/features/maps.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,14 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
125125
# Collect/scale individual feature blocks
126126
block = feature_map(x, **kwargs).to_dense()
127127
block_ndim = len(feature_map.output_shape)
128-
128+
129129
# Handle broadcasting for lower-dimensional feature maps
130130
if block_ndim < ndim:
131-
# Determine how the tiling/broadcasting works for lower-dimensional feature maps
131+
# Determine how the tiling/broadcasting works for lower-dimensional
132+
# feature maps
132133
tile_shape = shape[-ndim:-block_ndim]
133134
num_copies = prod(tile_shape)
134-
135+
135136
# Scale down by sqrt of number of copies to maintain proper variance
136137
if num_copies > 1:
137138
block = block * (num_copies**-0.5)
@@ -155,9 +156,11 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
155156
@property
156157
def raw_output_shape(self) -> Size:
157158
# Handle empty DirectSumFeatureMap case - can occur when:
158-
# 1. Purposely start with an empty container and plan to append feature maps later, or
159+
# 1. Purposely start with an empty container and plan to append feature
160+
# maps later, or
159161
# 2. Deleted the last entry and the list is now length-zero.
160-
# Returning Size([]) keeps the object in a queryable state until real feature maps are added.
162+
# Returning Size([]) keeps the object in a queryable state until real
163+
# feature maps are added.
161164
if not self:
162165
return Size([])
163166

@@ -217,12 +220,13 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
217220
for feature_map in self:
218221
block = feature_map(x, **kwargs)
219222
block_ndim = len(feature_map.output_shape)
220-
223+
221224
# Handle blocks that match the target dimensionality
222225
if block_ndim == ndim:
223226
# Convert LinearOperator to dense tensor if needed
224227
block = block.to_dense() if isinstance(block, LinearOperator) else block
225-
# Ensure block is in sparse format for efficient block diagonal construction
228+
# Ensure block is in sparse format for efficient block diagonal
229+
# construction
226230
block = block if block.is_sparse else block.to_sparse()
227231
else:
228232
# For lower-dimensional blocks, we need to expand dimensions
@@ -234,7 +238,7 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
234238
)
235239
block = block.to_dense()[multi_index]
236240
blocks.append(block)
237-
241+
238242
# Construct sparse block diagonal matrix from all blocks
239243
return sparse_block_diag(blocks, base_ndim=ndim)
240244

botorch/sampling/pathwise/prior_samplers.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,11 @@ def _draw_kernel_feature_paths_MultiTaskGP(
157157
if not isinstance(model.covar_module, ProductKernel):
158158
# Fallback for non-ProductKernel cases (legacy support)
159159
import warnings
160+
160161
warnings.warn(
161-
f"MultiTaskGP with non-ProductKernel detected ({type(model.covar_module)}). "
162-
"Consider using ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
162+
f"MultiTaskGP with non-ProductKernel detected "
163+
f"({type(model.covar_module)}). Consider using "
164+
"ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
163165
UserWarning,
164166
)
165167
combined_kernel = model.covar_module
@@ -178,7 +180,8 @@ def _draw_kernel_feature_paths_MultiTaskGP(
178180
else:
179181
data_kernel = deepcopy(kernel)
180182
else:
181-
# If no active_dims on data kernel, add them so downstream helpers don't error
183+
# If no active_dims on data kernel, add them so downstream
184+
# helpers don't error
182185
data_kernel = deepcopy(kernel)
183186
data_kernel.active_dims = torch.LongTensor(
184187
[
@@ -202,8 +205,9 @@ def _draw_kernel_feature_paths_MultiTaskGP(
202205
# Ensure the data kernel was found
203206
if data_kernel is None:
204207
raise ValueError(
205-
f"Could not identify data kernel from ProductKernel. "
206-
"MTGPs should follow the standard ProductKernel(IndexKernel, SomeOtherKernel) pattern."
208+
"Could not identify data kernel from ProductKernel. "
209+
"MTGPs should follow the standard "
210+
"ProductKernel(IndexKernel, SomeOtherKernel) pattern."
207211
)
208212

209213
# Use the existing product kernel structure

botorch/sampling/pathwise/update_strategies.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,11 @@ def _draw_kernel_feature_paths_MultiTaskGP(
182182
# Fallback for non-ProductKernel cases (legacy support)
183183
# This should be rare as MTGPs typically use ProductKernels by definition
184184
import warnings
185+
185186
warnings.warn(
186-
f"MultiTaskGP with non-ProductKernel detected ({type(model.covar_module)}). "
187-
"Consider using ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
187+
f"MultiTaskGP with non-ProductKernel detected "
188+
f"({type(model.covar_module)}). Consider using "
189+
"ProductKernel(IndexKernel, SomeOtherKernel) for better compatibility.",
188190
UserWarning,
189191
)
190192
combined_kernel = model.covar_module
@@ -203,7 +205,8 @@ def _draw_kernel_feature_paths_MultiTaskGP(
203205
else:
204206
data_kernel = deepcopy(kernel)
205207
else:
206-
# If no active_dims on data kernel, add them so downstream helpers don't error
208+
# If no active_dims on data kernel, add them so downstream
209+
# helpers don't error
207210
data_kernel = deepcopy(kernel)
208211
data_kernel.active_dims = torch.LongTensor(
209212
[index for index in range(num_inputs) if index != task_index],
@@ -223,8 +226,9 @@ def _draw_kernel_feature_paths_MultiTaskGP(
223226
# Ensure data kernel was found
224227
if data_kernel is None:
225228
raise ValueError(
226-
f"Could not identify data kernel from ProductKernel. "
227-
"MTGPs should follow the standard ProductKernel(IndexKernel, SomeOtherKernel) pattern."
229+
"Could not identify data kernel from ProductKernel. "
230+
"MTGPs should follow the standard "
231+
"ProductKernel(IndexKernel, SomeOtherKernel) pattern."
228232
)
229233

230234
# Use the existing product kernel structure

website/docusaurus.config.js

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,47 @@ module.exports={
5252
"sidebarPath": "../website/sidebars.js",
5353
remarkPlugins: [remarkMath],
5454
rehypePlugins: [rehypeKatex],
55+
exclude: [
56+
"**/tutorials/bope/**",
57+
"**/tutorials/turbo_1/**",
58+
"**/tutorials/baxus/**",
59+
"**/tutorials/scalable_constrained_bo/**",
60+
"**/tutorials/saasbo/**",
61+
"**/tutorials/cost_aware_bayesian_optimization/**",
62+
"**/tutorials/Multi_objective_multi_fidelity_BO/**",
63+
"**/tutorials/bo_with_warped_gp/**",
64+
"**/tutorials/thompson_sampling/**",
65+
"**/tutorials/ibnn_bo/**",
66+
"**/tutorials/custom_model/**",
67+
"**/tutorials/multi_objective_bo/**",
68+
"**/tutorials/constrained_multi_objective_bo/**",
69+
"**/tutorials/robust_multi_objective_bo/**",
70+
"**/tutorials/decoupled_mobo/**",
71+
"**/tutorials/custom_acquisition/**",
72+
"**/tutorials/fit_model_with_torch_optimizer/**",
73+
"**/tutorials/compare_mc_analytic_acquisition/**",
74+
"**/tutorials/optimize_with_cmaes/**",
75+
"**/tutorials/optimize_stochastic/**",
76+
"**/tutorials/batch_mode_cross_validation/**",
77+
"**/tutorials/one_shot_kg/**",
78+
"**/tutorials/max_value_entropy/**",
79+
"**/tutorials/GIBBON_for_efficient_batch_entropy_search/**",
80+
"**/tutorials/risk_averse_bo_with_environmental_variables/**",
81+
"**/tutorials/risk_averse_bo_with_input_perturbations/**",
82+
"**/tutorials/constraint_active_search/**",
83+
"**/tutorials/information_theoretic_acquisition_functions/**",
84+
"**/tutorials/relevance_pursuit_robust_regression/**",
85+
"**/tutorials/meta_learning_with_rgpe/**",
86+
"**/tutorials/vae_mnist/**",
87+
"**/tutorials/multi_fidelity_bo/**",
88+
"**/tutorials/discrete_multi_fidelity_bo/**",
89+
"**/tutorials/composite_bo_with_hogp/**",
90+
"**/tutorials/composite_mtbo/**",
91+
"**/notebooks_community/clf_constrained_bo/**",
92+
"**/notebooks_community/hentropy_search/**",
93+
"**/notebooks_community/multi_source_bo/**",
94+
"**/notebooks_community/vbll_thompson_sampling/**"
95+
],
5596
},
5697
"blog": {},
5798
"theme": {

website/sidebars.js

Lines changed: 34 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,65 +5,6 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
const tutorials = () => {
9-
const allTutorialMetadata = require('./tutorials.json');
10-
const tutorialsSidebar = [{
11-
type: 'category',
12-
label: 'Tutorials',
13-
collapsed: false,
14-
items: [
15-
{
16-
type: 'doc',
17-
id: 'tutorials/index',
18-
label: 'Overview',
19-
},
20-
],
21-
},];
22-
for (var category in allTutorialMetadata) {
23-
const categoryItems = allTutorialMetadata[category];
24-
const items = [];
25-
categoryItems.map(item => {
26-
items.push({
27-
type: 'doc',
28-
label: item.title,
29-
id: `tutorials/${item.id}/index`,
30-
});
31-
});
32-
33-
tutorialsSidebar.push({
34-
type: 'category',
35-
label: category,
36-
items: items,
37-
});
38-
}
39-
return tutorialsSidebar;
40-
};
41-
42-
const notebooks_community = () => {
43-
const allNotebookItems = require('./notebooks_community.json');
44-
const items = [
45-
{
46-
type: 'doc',
47-
id: 'notebooks_community/index',
48-
label: 'Overview',
49-
},
50-
];
51-
allNotebookItems.map(item => {
52-
items.push({
53-
type: 'doc',
54-
label: item.title,
55-
id: `notebooks_community/${item.id}/index`,
56-
});
57-
});
58-
const notebooksSidebar = [{
59-
type: 'category',
60-
label: 'Community Notebooks',
61-
collapsed: false,
62-
items: items,
63-
},];
64-
return notebooksSidebar;
65-
};
66-
678
export default {
689
"docs": {
6910
"About": ["introduction", "design_philosophy", "botorch_and_ax", "papers"],
@@ -72,6 +13,37 @@ export default {
7213
"Advanced Topics": ["constraints", "objectives", "batching", "samplers"],
7314
"Multi-Objective Optimization": ["multi_objective"]
7415
},
75-
tutorials: tutorials(),
76-
"notebooks_community": notebooks_community(),
77-
}
16+
"tutorials": [
17+
{
18+
type: 'category',
19+
label: 'Tutorials',
20+
collapsed: false,
21+
items: [
22+
{
23+
type: 'doc',
24+
id: 'tutorials/index',
25+
label: 'Overview',
26+
},
27+
{
28+
type: 'doc',
29+
id: 'tutorials/closed_loop_botorch_only/index',
30+
label: 'Closed Loop BoTorch Only',
31+
},
32+
],
33+
},
34+
],
35+
"notebooks_community": [
36+
{
37+
type: 'category',
38+
label: 'Community Notebooks',
39+
collapsed: false,
40+
items: [
41+
{
42+
type: 'doc',
43+
id: 'notebooks_community/index',
44+
label: 'Overview',
45+
},
46+
],
47+
},
48+
],
49+
}

0 commit comments

Comments
 (0)