Skip to content

Commit

Permalink
Merge pull request #392 from lsst/tickets/DM-46710
Browse files Browse the repository at this point in the history
DM-46710: Disable jax in Gaussian process interp
  • Loading branch information
PFLeget authored Oct 9, 2024
2 parents ebc1f9d + 37a273c commit c081298
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 154 deletions.
156 changes: 5 additions & 151 deletions python/lsst/meas/algorithms/gp_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,10 @@
import copy
import treegp

import jax
from jax import jit
import jax.numpy as jnp

import logging

__all__ = [
"InterpolateOverDefectGaussianProcess",
"GaussianProcessJax",
"GaussianProcessTreegp",
]

Expand Down Expand Up @@ -62,7 +57,6 @@ def updateMaskFromArray(mask, bad_pixel, interpBit):
# TO DO --> might be better: mask.array[int(bad_pixel[:,1]-y0), int(bad_pixel[:,0]-x)] |= interpBit


@jit
def median_with_mad_clipping(data, mad_multiplier=2.0):
"""
Calculate the median of the input data after applying Median Absolute Deviation (MAD) clipping.
Expand Down Expand Up @@ -90,145 +84,14 @@ def median_with_mad_clipping(data, mad_multiplier=2.0):
>>> median_with_mad_clipping(data)
3.5
"""
median = jnp.median(data)
mad = jnp.median(jnp.abs(data - median))
median = np.median(data)
mad = np.median(np.abs(data - median))
clipping_range = mad_multiplier * mad
clipped_data = jnp.clip(data, median - clipping_range, median + clipping_range)
median_clipped = jnp.median(clipped_data)
clipped_data = np.clip(data, median - clipping_range, median + clipping_range)
median_clipped = np.median(clipped_data)
return median_clipped


@jit
def jax_rbf_kernel(x1, x2, sigma, correlation_length):
    """Evaluate an isotropic radial basis function (RBF) kernel.

    Every row of ``x1`` is paired with every row of ``x2`` and the
    kernel ``sigma**2 * exp(-0.5 * d**2 / correlation_length**2)`` is
    evaluated on the squared Euclidean distance ``d**2`` of each pair.

    Parameters
    ----------
    x1 : `np.array`
        Locations of the first set of points, one point per row
        (n1, n_features).
    x2 : `np.array`
        Locations of the second set of points, one point per row
        (n2, n_features).
    sigma : `float`
        The scale parameter of the kernel; it enters squared.
    correlation_length : `float`
        The correlation length parameter of the kernel.

    Returns
    -------
    kernel : `np.array`
        RBF kernel matrix with shape (n1, n2).
    """
    # Broadcasting builds every pairwise coordinate difference at once.
    pairwise_offsets = x1[:, None, :] - x2[None, :, :]
    squared_separation = jnp.sum(pairwise_offsets**2, axis=-1)
    exponent = -0.5 * squared_separation / (correlation_length**2)
    amplitude = sigma**2
    return amplitude * jnp.exp(exponent)


@jit
def jax_get_alpha(y, kernel):
    """Solve ``kernel @ alpha = y`` through a Cholesky factorisation.

    Parameters
    ----------
    y : `np.array`
        The target values of the Gaussian Process.
    kernel : `np.array`
        The kernel matrix of the Gaussian Process.

    Returns
    -------
    alpha : `np.array`
        Solution of the linear system, returned as a column of shape
        (len(alpha), 1).
    """
    # Upper-triangular factor; the ``False`` in the tuple below tells
    # cho_solve that the factor is not lower-triangular.
    upper_factor = jax.scipy.linalg.cholesky(kernel, overwrite_a=True, lower=False)
    solution = jax.scipy.linalg.cho_solve(
        (upper_factor, False), y, overwrite_b=False
    )
    n_rows = len(solution)
    return solution.reshape((n_rows, 1))


@jit
def jax_get_gp_predict(kernel_rect, alpha):
    """Evaluate the Gaussian Process prediction from a kernel and alpha.

    Parameters
    ----------
    kernel_rect : `np.array`
        The kernel matrix that is multiplied with ``alpha``.
    alpha : `np.array`
        The alpha vector from the Cholesky solution.

    Returns
    -------
    `np.array`
        The predicted values of y: the first row of the transposed
        product ``kernel_rect . alpha``.
    """
    product = jnp.dot(kernel_rect, alpha)
    # Transposing and taking row 0 extracts the single prediction column.
    transposed = product.T
    return transposed[0]


class GaussianProcessJax:
    """Gaussian Process regression in JAX.

    The kernel is assumed to be an isotropic RBF kernel, and the system
    is solved using an exact Cholesky decomposition.
    The interpolation solution is obtained by solving the linear system:
    y_interp = kernel_rect @ (kernel + y_err**2 * I)^-1 @ y_training.
    See the Rasmussen and Williams book for more details.

    The numerical helpers called by this class (``jax_rbf_kernel``,
    ``jax_get_alpha`` and ``jax_get_gp_predict``) are decorated with
    ``@jit`` so that they are compiled.

    Packages such as tinygp, also implemented in JAX, exist. This class
    is a custom implementation of Gaussian Processes, which allows for
    setting the hyperparameters, fine-tuning the mean function, and
    other specifications.

    Parameters
    ----------
    std : `float`, optional
        Standard deviation of the Gaussian Process kernel. Default is 1.0.
    correlation_length : `float`, optional
        Correlation length of the Gaussian Process kernel. Default is 1.0.
    white_noise : `float`, optional
        White noise level of the Gaussian Process. Default is 0.0, which
        is replaced internally by 1e-5 (see the note in ``__init__``).
    mean : `float`, optional
        Mean value of the Gaussian Process. Default is 0.0.
    """

    def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
        self.std = std
        self.correlation_length = correlation_length
        self.white_noise = white_noise
        self.mean = mean
        # Training state; both stay None until fit() has been called.
        self._x = None
        self._alpha = None

        # This looks odd, but it is justified: if no noise is provided,
        # the kernel matrix may be invertible in principle and still
        # fail to invert numerically (det(K) ~ 0). Adding a small amount
        # of noise allows a numerical solution to be computed when no
        # external noise is given. This will not happen on real images,
        # but it helps the unit tests.
        if self.white_noise == 0.0:
            self.white_noise = 1e-5

    def fit(self, x_train, y_train):
        """Train the Gaussian Process on the given data.

        Parameters
        ----------
        x_train : `np.array`
            Locations of the training points, one point per row.
        y_train : `np.array`
            Values at the training points; ``self.mean`` is subtracted
            before solving.
        """
        y = y_train - self.mean
        self._x = x_train
        kernel = jax_rbf_kernel(x_train, x_train, self.std, self.correlation_length)
        # White noise enters as a variance added on the kernel diagonal.
        y_err = jnp.ones(len(x_train[:, 0])) * self.white_noise
        kernel += jnp.eye(len(y_err)) * (y_err**2)
        self._alpha = jax_get_alpha(y, kernel)

    def predict(self, x_predict):
        """Predict values at new locations.

        Parameters
        ----------
        x_predict : `np.array`
            Locations at which to predict, one point per row.

        Returns
        -------
        y_pred : `np.array`
            Predicted values, with ``self.mean`` added back.

        Raises
        ------
        RuntimeError
            Raised if ``fit`` has not been called beforehand.
        """
        if self._x is None or self._alpha is None:
            raise RuntimeError(
                "GaussianProcessJax.predict() called before fit(); "
                "call fit(x_train, y_train) first."
            )
        kernel_rect = jax_rbf_kernel(
            x_predict, self._x, self.std, self.correlation_length
        )
        y_pred = jax_get_gp_predict(kernel_rect, self._alpha)
        return y_pred + self.mean


class GaussianProcessTreegp:
"""
Gaussian Process Treegp class for Gaussian Process interpolation.
Expand Down Expand Up @@ -313,8 +176,6 @@ class InterpolateOverDefectGaussianProcess:
The masked image containing the defects to be interpolated.
defects : `list`[`str`], optional
The types of defects to be interpolated. Default is ["SAT"].
method : `str`, optional
The method to use for GP interpolation. Must be either "jax" or "treegp". Default is "treegp".
fwhm : `float`, optional
The full width at half maximum (FWHM) of the PSF. Default is 5.
bin_spacing : `int`, optional
Expand All @@ -334,7 +195,6 @@ def __init__(
self,
masked_image,
defects=["SAT"],
method="treegp",
fwhm=5,
bin_image=True,
bin_spacing=10,
Expand All @@ -343,12 +203,6 @@ def __init__(
correlation_length_cut=5,
log=None,
):
if method == "jax":
self.GaussianProcess = GaussianProcessJax
elif method == "treegp":
self.GaussianProcess = GaussianProcessTreegp
else:
raise ValueError("Invalid method. Must be 'jax' or 'treegp'.")

self.log = log or logging.getLogger(__name__)

Expand Down Expand Up @@ -485,7 +339,7 @@ def interpolate_masked_sub_image(self, masked_sub_image):
# put this after binning as computing median is O(n*log(n))
clipped_median = median_with_mad_clipping(good_pixel[:, 2:])

gp = self.GaussianProcess(
gp = GaussianProcessTreegp(
std=np.sqrt(kernel_amplitude),
correlation_length=self.correlation_length,
white_noise=white_noise,
Expand Down
4 changes: 1 addition & 3 deletions tests/test_gp_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ def setUp(self):
# np.random.seed(12345)
# self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape)

@lsst.utils.tests.methodParameters(method=("jax"))
def test_interpolation(self, method: str):
def test_interpolation(self):
"""Test that the interpolation is done correctly.
Parameters
Expand All @@ -155,7 +154,6 @@ def test_interpolation(self, method: str):
gp = InterpolateOverDefectGaussianProcess(
self.maskedimage,
defects=["BAD", "SAT", "CR", "EDGE"],
method=method,
fwhm=self.correlation_length,
bin_image=False,
bin_spacing=30,
Expand Down

0 comments on commit c081298

Please sign in to comment.