Refactor botorch/sampling/pathwise and add support for product kernels #2838

Open
wants to merge 12 commits into
base: main

Conversation

@seashoo seashoo commented May 5, 2025

Motivation

Hi! I'm Sahran Ashoor, an undergraduate research assistant at the Uncertainty Quantification Lab at the University of Houston. I work under Dr. Ruda Zhang and Taiwo Adebiyi, both of whom have already spoken with Max Balandat about incorporating a rebase of botorch/sampling/pathwise (largely written by James T. Wilson). The changes in this pull request are my best attempt at faithfully implementing the change logs I was provided (product_kernel_diff.txt).

Have you read the Contributing Guidelines on pull requests?

Yes!

Project Overview

The primary goal was to make Wilson's original codebase compatible with the latest BoTorch version. To achieve this, we started from the original source code and test suites, which initially revealed several incompatibility issues. Our main contribution was carefully rebasing Wilson's code while preserving the logic for pathwise sampling. Importantly, all changes were confined to the botorch/sampling/pathwise directory to ensure seamless integration, and they pass both the local pathwise test suites and BoTorch's global test suites via GitHub workflows.

In terms of code logic, we relied on Wilson's unit tests for prior, updates, and posterior sampling, which we believe are sufficient to validate the correctness of the implementation. However, we welcome your feedback on this approach, and would appreciate any suggestions for additional tests or example scripts to further confirm the robustness of the changes. We are open to collaborating further on this effort.

Test Plan

(Write your test plan here. If you changed any code, please provide us with clear instructions on how you verified your changes work. Bonus points for screenshots and videos!)

The entire test suite was run with pytest. Through additional verification we've found that the logic may be offset, but we're hoping to work with you all to further validate these changes under the direction of Dr. Zhang. Expect further communication directly from my lab that will provide more insight into the rebase.

Related PRs

(If this PR adds or changes functionality, please take some time to update the docs at https://github.com/pytorch/botorch, and link to your PR here.)

N/A

@facebook-github-bot facebook-github-bot added the CLA Signed Do not delete this pull request or issue due to inactivity. label May 5, 2025

codecov bot commented May 5, 2025

Codecov Report

Attention: Patch coverage is 92.69341% with 51 lines in your changes missing coverage. Please review.

Project coverage is 99.70%. Comparing base (a268631) to head (0321b49).
Report is 28 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| botorch/sampling/pathwise/features/maps.py | 89.55% | 21 Missing ⚠️ |
| botorch/sampling/pathwise/features/generators.py | 83.60% | 10 Missing ⚠️ |
| botorch/sampling/pathwise/update_strategies.py | 84.61% | 8 Missing ⚠️ |
| botorch/sampling/pathwise/utils/helpers.py | 95.83% | 6 Missing ⚠️ |
| botorch/utils/types.py | 66.66% | 3 Missing ⚠️ |
| botorch/sampling/pathwise/utils/mixins.py | 97.22% | 2 Missing ⚠️ |
| botorch/sampling/pathwise/paths.py | 97.50% | 1 Missing ⚠️ |

Additional details and impacted files
@@             Coverage Diff             @@
##              main    #2838      +/-   ##
===========================================
- Coverage   100.00%   99.70%   -0.30%     
===========================================
  Files          211      214       +3     
  Lines        19397    19824     +427     
===========================================
+ Hits         19397    19766     +369     
- Misses           0       58      +58     

@Balandat
Contributor

Balandat commented May 6, 2025

Thanks @seashoo for the PR - this is a big one! It'll take me a bit of time to review this in detail, I plan to do a first higher-level pass this week.

> Through additional verification we've found that the logic may be offset

What exactly does this mean?

@TaiwoAdebiyi23

> Thanks @seashoo for the PR - this is a big one! It'll take me a bit of time to review this in detail, I plan to do a first higher-level pass this week.
>
> > Through additional verification we've found that the logic may be offset
>
> What exactly does this mean?

Hi @Balandat,

Thanks for the response! We've included a more detailed Project Overview section in the pull request description to clarify our validation approach. Specifically, we utilized the existing unit test files, which cover prior, updates, and posterior sampling, and ensured that all tests passed as part of this rebase. While these tests are comprehensive, we welcome any additional guidance you might have on further validating the code's robustness.

@Balandat Balandat changed the title Rebase for botorch/sampling/pathwise - Dr. Ruda Zhang, Taiwo Adebiyi, Sahran Ashoor - University of Houson Refactor botorch/sampling/pathwise and add support for product kernels May 18, 2025
Contributor

@Balandat Balandat left a comment

I went over the main code in the PR in detail; overall this looks great, thanks for the effort and patching some gaps (e.g. _gaussian_update_ModelListGP). I have not reviewed the testing code in detail, but can do that after the next pass.

The key things to address are:

  1. Some additions in the patch file were not included here - curious to understand why (and if this was an oversight let's add them in - I pointed out which ones).
  2. Currently the tests still have some coverage gaps based on the codecov report here. Please add some test cases to also cover the currently uncovered lines.

Comment on lines 7 to 15
r"""
.. [rahimi2007random]
A. Rahimi and B. Recht. Random features for large-scale kernel machines.
Advances in Neural Information Processing Systems 20 (2007).

.. [sutherland2015error]
D. J. Sutherland and J. Schneider. On the error of random Fourier features.
arXiv preprint arXiv:1506.02785 (2015).
"""
Contributor

Why remove these references?

Author

My apologies, this was an oversight in a cleanup commit where I aimed to procedurally remove some cluttered developer comments. I've restored this docstring and added a few more - this is further addressed in your other comment.

# generators.
# It defines a callable that takes a kernel and dimension parameters and returns a
# KernelFeatureMap.
TKernelFeatureMapGenerator = Callable[[kernels.Kernel, int, int], KernelFeatureMap]
Contributor

thanks for adding these comments, they're very helpful

num_inputs: int,
num_outputs: int,
num_random_features: int = 1024,
num_ambient_inputs: Optional[int] = None,
Contributor

We generally use this PEP 604 style type definition - let's update this throughout the code here.

Suggested change
- num_ambient_inputs: Optional[int] = None,
+ num_ambient_inputs: int | None = None,

Author

Thank you for pointing this out. I’ve implemented PEP 604 definitions throughout pathwise.

Comment on lines 50 to 55
# IMPLEMENTATION NOTE: This function serves as the main entry point for generating
# feature maps from kernels. It uses the dispatcher to call the appropriate handler
# based on the kernel type. The function has been updated from the original
# implementation
# to use more descriptive parameter names (num_ambient_inputs instead of num_inputs,
# and num_random_features instead of num_outputs) to better reflect their purpose.
Contributor

Could you please keep the docstring format the same as in the previous code (this is also contained in product_kernel_diff.txt)? This applies throughout the code.

Author

I’ve replaced many of the large comment instances with the intended Google-style docstrings. I left a few additional docstrings where I thought they might be helpful.

device=kernel.device,
dtype=kernel.dtype,
)
output_transforms = [transforms.SineCosineTransform(constant)]
Contributor

This doesn't seem to be the same as the behavior in the patch file? Relevant code snippet of what I think this should look like:

    output_transforms = [transforms.ConstantMulTransform(constant)]
    if cosine_only:
        bias = 2 * pi * torch.rand(num_random_features, **tkwargs)
        num_raw_features = num_random_features
        output_transforms.append(transforms.CosineTransform())
    elif num_random_features % 2:
        raise UnsupportedError(
            f"Expected an even number of random features, but {num_random_features=}."
        )
    else:
        bias = None
        num_raw_features = num_random_features // 2
        output_transforms.append(transforms.SineCosineTransform())

Author

@seashoo seashoo Jul 29, 2025

Yes, I agree. The original SineCosineTransform(constant) approach was too rigid since it baked scaling into the transform class. I've implemented the conditional logic you described. Thanks for the heads up!
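
To make the two branches in this exchange concrete, here is a rough standalone sketch of random Fourier features for a one-dimensional RBF kernel. It uses only the Python standard library rather than BoTorch's transform classes, and `make_rff_map`, `approx_kernel`, `rbf_kernel`, and the `sqrt(2 / num_random_features)` scaling are assumptions made for illustration, not the code in this PR.

```python
import math
import random


def make_rff_map(num_random_features, lengthscale=1.0, cosine_only=False, seed=0):
    """Return a random Fourier feature map for a 1-D RBF kernel (illustrative only)."""
    rng = random.Random(seed)
    # Assumed scaling so that the feature inner product estimates the kernel.
    constant = math.sqrt(2.0 / num_random_features)
    if cosine_only:
        # One random phase per feature; every feature is a shifted cosine.
        num_raw_features = num_random_features
        bias = [2.0 * math.pi * rng.random() for _ in range(num_raw_features)]
    elif num_random_features % 2:
        raise ValueError(
            f"Expected an even number of random features, but {num_random_features=}."
        )
    else:
        # Each frequency yields a sine and a cosine feature, so draw half as many.
        num_raw_features = num_random_features // 2
        bias = None
    # Frequencies follow the RBF spectral density N(0, 1 / lengthscale**2).
    weights = [rng.gauss(0.0, 1.0 / lengthscale) for _ in range(num_raw_features)]

    def feature_map(x):
        if bias is not None:
            return [constant * math.cos(w * x + b) for w, b in zip(weights, bias)]
        sines = [constant * math.sin(w * x) for w in weights]
        cosines = [constant * math.cos(w * x) for w in weights]
        return sines + cosines

    return feature_map


def approx_kernel(feature_map, x, y):
    """Inner product of the features, which approximates k(x, y)."""
    return sum(a * b for a, b in zip(feature_map(x), feature_map(y)))


def rbf_kernel(x, y, lengthscale=1.0):
    """Exact RBF kernel value, used as the reference."""
    return math.exp(-((x - y) ** 2) / (2.0 * lengthscale**2))
```

With a few thousand features, `approx_kernel(make_rff_map(4000), 0.3, 1.0)` and its `cosine_only=True` counterpart should both land close to `rbf_kernel(0.3, 1.0)`. Both variants return exactly `num_random_features` outputs, which is why the paired sine-cosine branch needs an even count.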

noise_values = torch.randn_like(sample_values).unsqueeze(-1)
noise_values = noise_covariance.cholesky() @ noise_values
sample_values = sample_values + noise_values.squeeze(-1)
# Generate noise values with correct shape
Contributor

thanks for fixing this

Comment on lines +164 to +168
task_index = (
num_inputs + model._task_feature
if model._task_feature < 0
else model._task_feature
)
Contributor

It would be better to do this so we can always assume that it's positive and we don't have to do this custom handling. But I think this is ok for now as is, that change is beyond the scope of this PR.

Suggested change
- task_index = (
-     num_inputs + model._task_feature
-     if model._task_feature < 0
-     else model._task_feature
- )
+ # TODO: Change `MultiTaskGP` to normalize the task feature in its constructor.
+ task_index = (
+     num_inputs + model._task_feature
+     if model._task_feature < 0
+     else model._task_feature
+ )

Author

Great suggestion! I would be willing to come back to this in a separate PR.
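
As a small illustration of the normalization being discussed, the helper below maps a possibly negative column index to its non-negative equivalent. `normalize_task_feature` is a hypothetical name chosen for this sketch; it is not an existing BoTorch function and does not reflect how `MultiTaskGP` is implemented.

```python
def normalize_task_feature(task_feature: int, num_inputs: int) -> int:
    """Map a possibly negative column index to the equivalent non-negative index."""
    if not -num_inputs <= task_feature < num_inputs:
        raise ValueError(
            f"task_feature={task_feature} is out of range for {num_inputs} inputs."
        )
    # Same result as `num_inputs + task_feature if task_feature < 0 else task_feature`.
    return task_feature % num_inputs
```

For example, `normalize_task_feature(-1, 4)` gives `3`, and a non-negative index such as `2` is returned unchanged. Doing this once at construction time would let downstream code assume a non-negative index.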



@GaussianUpdate.register(ModelListGP, LikelihoodList)
def _gaussian_update_ModelListGP(
Contributor

This one is completely new, right? Very nice!

class SineCosineTransform(TensorTransform):
    r"""A transform that returns concatenated sine and cosine features."""

    def __init__(self, scale: Optional[Tensor] = None):
Contributor

Suggested change
- def __init__(self, scale: Optional[Tensor] = None):
+ def __init__(self, scale: Tensor | None = None):

Comment on lines 18 to 22
# Removed unused imports
# from botorch.sampling.pathwise.utils.transforms import (
# ChainedTransform,
# FeatureSelector
# )
Contributor

if unused let's delete them outright - applies throughout the code

Suggested change
- # Removed unused imports
- # from botorch.sampling.pathwise.utils.transforms import (
- # ChainedTransform,
- # FeatureSelector
- # )

@seashoo
Author

seashoo commented Jul 29, 2025

@Balandat, thank you for the comprehensive review and detailed feedback on Wilson's product_kernel.diff implementation! My apologies for the time it has taken to get back to you. Some of the implementations took quite a while to fully realize, and I've been balancing this work alongside my current internship.

I'm excited to collaborate with you at a faster pace now that I've freed up time! If you have questions about any of my specific implementations, feel free to ask in a reply to any of the comments here; I'll be able to respond much more quickly now. I've gone ahead and resolved some of the upstream merge conflicts that appeared while I was away, and I've also filled in the code coverage gaps as you asked.

Here's a quick summary of the major changes implemented to address your concerns:

Mathematical Issues Resolved

Product Kernel Implementation

Completely redesigned the product_feature_generator, which was failing with relative errors around 0.75 against a tolerance of ~0.094:

  • Separated finite-dimensional from infinite-dimensional sub-kernels
  • Used Hadamard products for infinite-dimensional kernel combinations
  • Applied outer products to merge finite-dimensional with combined infinite-dimensional kernels
  • Automatically enabled cosine_only=True for multiple infinite-dimensional kernels to avoid problematic tensor products
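
One way to read the Hadamard-product bullet, sketched here with the standard library only: when each sub-kernel gets its own independent cosine-only feature map, multiplying the features elementwise gives an unbiased estimate of the product kernel, without the feature count growing as it would under a full tensor product. The helper names and the scaling below are assumptions for illustration and are not the PR's `product_feature_generator`.

```python
import math
import random


def make_cosine_features(num_features, lengthscale, rng):
    """Unscaled cosine-only features sqrt(2) * cos(w * x + b) for a 1-D RBF kernel."""
    weights = [rng.gauss(0.0, 1.0 / lengthscale) for _ in range(num_features)]
    bias = [2.0 * math.pi * rng.random() for _ in range(num_features)]

    def feature_map(x):
        return [math.sqrt(2.0) * math.cos(w * x + b) for w, b in zip(weights, bias)]

    return feature_map


def hadamard_product_features(feature_maps, num_features):
    """Multiply independent feature maps elementwise, then scale by 1 / sqrt(M)."""
    scale = 1.0 / math.sqrt(num_features)

    def feature_map(x):
        columns = zip(*(fm(x) for fm in feature_maps))
        return [scale * math.prod(column) for column in columns]

    return feature_map


def approx_product_kernel(num_features=10000, lengthscales=(1.0, 2.0), x=0.2, y=0.9, seed=0):
    """Estimate the product of RBF kernels at (x, y) via Hadamard-product features."""
    rng = random.Random(seed)
    maps = [make_cosine_features(num_features, ls, rng) for ls in lengthscales]
    product_map = hadamard_product_features(maps, num_features)
    return sum(a * b for a, b in zip(product_map(x), product_map(y)))


def exact_product_kernel(lengthscales=(1.0, 2.0), x=0.2, y=0.9):
    """Reference value: the product of the individual RBF kernel values."""
    return math.prod(math.exp(-((x - y) ** 2) / (2.0 * ls**2)) for ls in lengthscales)
```

The same elementwise product applied to paired sine-cosine features would not be unbiased, because the sine and cosine terms no longer combine into a function of `x - y` alone. That is a plausible motivation for switching to `cosine_only=True` when several infinite-dimensional kernels are multiplied.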

Transform System Overhaul

Fixed scaling and coordination issues across the transform pipeline

  • SineCosineTransform: Added conditional logic instead of rigid constant scaling
  • Parameter Passing: Improved explicit num_ambient_inputs handling vs complex kwargs manipulation
  • Transform Chaining: Enhanced append_transform utility to properly handle None cases

Architectural Improvements

Feature Map Redesign

Built comprehensive feature map architecture

  • DirectSumFeatureMap: Rewrote raw_output_shape method to handle mixed dimensionality without MagicMock detection
  • SparseDirectSumFeatureMap: Implemented for completeness, available for manual use
  • HadamardProductFeatureMap/OuterProductFeatureMap: Enhanced for proper kernel composition

Code Organization

Restructured from monolithic utils.py into modular package

  • Split into dedicated helpers, mixins, and transforms modules
  • Improved dispatcher system with specific return types (MultitaskKernelFeatureMap, DirectSumFeatureMap, etc.)
  • Fixed sub_kernels parameter usage for LCM kernel compatibility

Code Quality Enhancements

Modern Python Standards

  • Adopted PEP 604 union syntax (A | B vs Union[A, B]) throughout pathwise directory
  • Replaced large comment blocks with Google-style docstrings
  • Cleaned up redundant imports and unused attributes

Type Safety

Enhanced type annotations and return type specificity for better IDE support

Technical details regarding the issues + approaches taken in response to your suggestions are further addressed in my replies!

Labels
CLA Signed Do not delete this pull request or issue due to inactivity.

4 participants