DM-46710: Disable jax in Gaussian process interp #392

Merged 1 commit on Oct 9, 2024
python/lsst/meas/algorithms/gp_interpolation.py: 156 changes (5 additions, 151 deletions)
@@ -26,15 +26,10 @@
import copy
import treegp

import jax
from jax import jit
import jax.numpy as jnp

import logging

__all__ = [
"InterpolateOverDefectGaussianProcess",
"GaussianProcessJax",
"GaussianProcessTreegp",
]

@@ -62,7 +57,6 @@ def updateMaskFromArray(mask, bad_pixel, interpBit):
# TO DO --> might be better: mask.array[int(bad_pixel[:,1]-y0), int(bad_pixel[:,0]-x)] |= interpBit


@jit
def median_with_mad_clipping(data, mad_multiplier=2.0):
"""
Calculate the median of the input data after applying Median Absolute Deviation (MAD) clipping.
Expand Down Expand Up @@ -90,145 +84,14 @@ def median_with_mad_clipping(data, mad_multiplier=2.0):
>>> median_with_mad_clipping(data)
3.5
"""
median = jnp.median(data)
mad = jnp.median(jnp.abs(data - median))
median = np.median(data)
mad = np.median(np.abs(data - median))
clipping_range = mad_multiplier * mad
clipped_data = jnp.clip(data, median - clipping_range, median + clipping_range)
median_clipped = jnp.median(clipped_data)
clipped_data = np.clip(data, median - clipping_range, median + clipping_range)
median_clipped = np.median(clipped_data)
return median_clipped
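# Editor's note (not part of the diff): a minimal NumPy-only check of the
# docstring example above, assuming data = [1, 2, 3, 4, 5, 6]. The median is
# 3.5 and the MAD is 1.5, so the clipping window is [0.5, 6.5]; nothing is
# clipped and the clipped median is still 3.5.
import numpy as np

data = np.array([1, 2, 3, 4, 5, 6])
median = np.median(data)                # 3.5
mad = np.median(np.abs(data - median))  # 1.5
clipped = np.clip(data, median - 2.0 * mad, median + 2.0 * mad)
print(np.median(clipped))               # 3.5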


@jit
def jax_rbf_kernel(x1, x2, sigma, correlation_length):
"""
Computes the radial basis function (RBF) kernel matrix.

Parameters:
-----------
x1 : `np.array`
Location of training data point with shape (n_samples, n_features).
x2 : `np.array`
Location of training/test data point with shape (n_samples, n_features).
sigma : `float`
The scale parameter of the kernel.
correlation_length : `float`
The correlation length parameter of the kernel.

Returns:
--------
kernel : `np.array`
RBF kernel matrix with shape (n_samples, n_samples).
"""
distance_squared = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
kernel = (sigma**2) * jnp.exp(-0.5 * distance_squared / (correlation_length**2))
return kernel
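# Editor's note (not part of the diff): the same isotropic RBF kernel written
# with plain NumPy, shown as a sketch of what the removed @jit version
# computed; the function name `numpy_rbf_kernel` is hypothetical.
import numpy as np

def numpy_rbf_kernel(x1, x2, sigma, correlation_length):
    # Pairwise squared Euclidean distances between rows of x1 and x2.
    distance_squared = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return (sigma**2) * np.exp(-0.5 * distance_squared / (correlation_length**2))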


@jit
def jax_get_alpha(y, kernel):
"""
Compute the alpha vector for Gaussian Process interpolation.

Parameters:
-----------
y : `np.array`
The target values of the Gaussian Process.
kernel : `np.array`
The kernel matrix of the Gaussian Process.

Returns:
--------
alpha : `np.array`
The alpha vector computed using the Cholesky decomposition and solution.

"""
factor = (jax.scipy.linalg.cholesky(kernel, overwrite_a=True, lower=False), False)
alpha = jax.scipy.linalg.cho_solve(factor, y, overwrite_b=False)
return alpha.reshape((len(alpha), 1))
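# Editor's note (not part of the diff): an equivalent alpha computation with
# SciPy in place of jax.scipy, assuming a symmetric positive-definite kernel;
# the function name `scipy_get_alpha` is hypothetical.
import numpy as np
from scipy.linalg import cho_factor, cho_solve

def scipy_get_alpha(y, kernel):
    # cho_factor returns the (c, lower) pair that cho_solve expects.
    factor = cho_factor(kernel, lower=False)
    alpha = cho_solve(factor, y)
    return alpha.reshape((len(alpha), 1))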


@jit
def jax_get_gp_predict(kernel_rect, alpha):
"""
Compute the predicted values of the GP using the given kernel and alpha (Cholesky solution).

Parameters:
-----------
kernel_rect : `np.array`
The kernel matrix.
alpha : `np.array`
The alpha vector from Cholesky solution.

Returns:
--------
`np.array`
The predicted values of y.

"""
return jnp.dot(kernel_rect, alpha).T[0]


class GaussianProcessJax:
"""
Gaussian Process regression in JAX.
The kernel is assumed to be an isotropic RBF kernel, and the system is solved
with an exact Cholesky decomposition.
The interpolation solution is obtained by solving the linear system:
y_interp = kernel_rect @ (kernel + y_err**2 * I)^-1 @ y_training.
See the Rasmussen and Williams book for more details.
Each function is decorated with @jit to compile it.
Packages such as tinygp also implement Gaussian Processes in JAX; this class
is a custom implementation that allows setting the hyperparameters,
fine-tuning the mean function, and other specifications.

Parameters:
-----------
std : `float`, optional
Standard deviation of the Gaussian Process kernel. Default is 1.0.
correlation_length : `float`, optional
Correlation length of the Gaussian Process kernel. Default is 1.0.
white_noise : `float`, optional
White noise level of the Gaussian Process. Default is 0.0.
mean : `float`, optional
Mean value of the Gaussian Process. Default is 0.0.

"""

def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
self.std = std
self.correlation_length = correlation_length
self.white_noise = white_noise
self.mean = mean
self._alpha = None

# This looks odd, but it is justified: in a GP, if no noise is
# provided, the kernel matrix can be numerically singular
# (det(K) ~ 0) even though it is invertible in principle.
# Adding a small amount of noise makes the solve numerically
# stable when no external noise is supplied. This will not
# happen on real images, but it helps the unit tests.
if self.white_noise == 0.0:
self.white_noise = 1e-5

def fit(self, x_train, y_train):
y = y_train - self.mean
self._x = x_train
kernel = jax_rbf_kernel(x_train, x_train, self.std, self.correlation_length)
y_err = jnp.ones(len(x_train[:, 0])) * self.white_noise
kernel += jnp.eye(len(y_err)) * (y_err**2)
self._alpha = jax_get_alpha(y, kernel)

def predict(self, x_predict):
kernel_rect = jax_rbf_kernel(
x_predict, self._x, self.std, self.correlation_length
)
y_pred = jax_get_gp_predict(kernel_rect, self._alpha)
return y_pred + self.mean
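# Editor's note (not part of the diff): a minimal NumPy/SciPy sketch of the
# linear algebra the removed class implemented, i.e.
# y_interp = kernel_rect @ (kernel + white_noise**2 * I)^-1 @ (y_train - mean) + mean.
# The function name `gp_interpolate` is hypothetical and not part of the module.
import numpy as np
from scipy.linalg import cho_factor, cho_solve

def gp_interpolate(x_train, y_train, x_predict,
                   std=1.0, correlation_length=1.0, white_noise=1e-5, mean=0.0):
    def rbf(a, b):
        # Isotropic squared-exponential kernel between rows of a and b.
        d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return (std**2) * np.exp(-0.5 * d2 / (correlation_length**2))

    # Training covariance with a white-noise nugget on the diagonal.
    kernel = rbf(x_train, x_train) + np.eye(len(x_train)) * white_noise**2
    alpha = cho_solve(cho_factor(kernel, lower=False), y_train - mean)
    # Cross-covariance between prediction and training locations.
    return rbf(x_predict, x_train) @ alpha + mean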


class GaussianProcessTreegp:
"""
Gaussian Process Treegp class for Gaussian Process interpolation.
@@ -313,8 +176,6 @@ class InterpolateOverDefectGaussianProcess:
The masked image containing the defects to be interpolated.
defects : `list`[`str`], optional
The types of defects to be interpolated. Default is ["SAT"].
method : `str`, optional
The method to use for GP interpolation. Must be either "jax" or "treegp". Default is "treegp".
fwhm : `float`, optional
The full width at half maximum (FWHM) of the PSF. Default is 5.
bin_spacing : `int`, optional
@@ -334,7 +195,6 @@ def __init__(
self,
masked_image,
defects=["SAT"],
method="treegp",
fwhm=5,
bin_image=True,
bin_spacing=10,
@@ -343,12 +203,6 @@
correlation_length_cut=5,
log=None,
):
if method == "jax":
self.GaussianProcess = GaussianProcessJax
elif method == "treegp":
self.GaussianProcess = GaussianProcessTreegp
else:
raise ValueError("Invalid method. Must be 'jax' or 'treegp'.")

self.log = log or logging.getLogger(__name__)

@@ -485,7 +339,7 @@ def interpolate_masked_sub_image(self, masked_sub_image):
# put this after binning as computing median is O(n*log(n))
clipped_median = median_with_mad_clipping(good_pixel[:, 2:])

gp = self.GaussianProcess(
gp = GaussianProcessTreegp(
std=np.sqrt(kernel_amplitude),
correlation_length=self.correlation_length,
white_noise=white_noise,
tests/test_gp_interp.py: 4 changes (1 addition, 3 deletions)
@@ -142,8 +142,7 @@ def setUp(self):
# np.random.seed(12345)
# self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape)

@lsst.utils.tests.methodParameters(method=("jax"))
def test_interpolation(self, method: str):
def test_interpolation(self):
"""Test that the interpolation is done correctly.

Parameters
@@ -155,7 +154,6 @@ def test_interpolation(self, method: str):
gp = InterpolateOverDefectGaussianProcess(
self.maskedimage,
defects=["BAD", "SAT", "CR", "EDGE"],
method=method,
fwhm=self.correlation_length,
bin_image=False,
bin_spacing=30,