From 7d133e04c2a69123f02abd777badd7f26437865d Mon Sep 17 00:00:00 2001 From: Christopher Waters Date: Wed, 13 Mar 2024 12:22:22 -0700 Subject: [PATCH] add gaussian process as an option in interpolateOverDefects. --- python/lsst/meas/algorithms/__init__.py | 2 + .../lsst/meas/algorithms/gp_interpolation.py | 510 ++++++++++++++++++ python/lsst/meas/algorithms/interp.cc | 2 +- python/lsst/meas/algorithms/interp.py | 25 + .../meas/algorithms/reinterpolate_pixels.py | 2 +- tests/test_gp_interp.py | 193 +++++++ 6 files changed, 732 insertions(+), 2 deletions(-) create mode 100644 python/lsst/meas/algorithms/gp_interpolation.py create mode 100644 python/lsst/meas/algorithms/interp.py create mode 100644 tests/test_gp_interp.py diff --git a/python/lsst/meas/algorithms/__init__.py b/python/lsst/meas/algorithms/__init__.py index 561e281d7..e750993d2 100644 --- a/python/lsst/meas/algorithms/__init__.py +++ b/python/lsst/meas/algorithms/__init__.py @@ -61,6 +61,8 @@ from .accumulator_mean_stack import * from .scaleVariance import * from .noise_covariance import * +from .gp_interpolation import * +from .interp import * from .reinterpolate_pixels import * from .setPrimaryFlags import * from .coaddBoundedField import * diff --git a/python/lsst/meas/algorithms/gp_interpolation.py b/python/lsst/meas/algorithms/gp_interpolation.py new file mode 100644 index 000000000..6de217f9a --- /dev/null +++ b/python/lsst/meas/algorithms/gp_interpolation.py @@ -0,0 +1,510 @@ +# This file is part of meas_algorithms. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import numpy as np +from lsst.meas.algorithms import CloughTocher2DInterpolatorUtils as ctUtils +from lsst.geom import Box2I, Point2I +from lsst.afw.geom import SpanSet +import copy +import treegp + +import jax +from jax import jit +import jax.numpy as jnp + +import logging + +__all__ = [ + "InterpolateOverDefectGaussianProcess", + "GaussianProcessJax", + "GaussianProcessTreegp", +] + + +def updateMaskFromArray(mask, bad_pixel, interpBit): + """ + Update the mask array with the given bad pixels. + + Parameters + ---------- + mask : `lsst.afw.image.MaskedImage` + The mask image to update. + bad_pixel : `np.array` + An array-like object containing the coordinates of the bad pixels. + Each row should contain the x and y coordinates of a bad pixel. + interpBit : `int` + The bit value to set for the bad pixels in the mask. + """ + x0 = mask.getX0() + y0 = mask.getY0() + for row in bad_pixel: + x = int(row[0] - x0) + y = int(row[1] - y0) + mask.array[y, x] |= interpBit + # TO DO --> might be better: mask.array[int(bad_pixel[:,1]-y0), int(bad_pixel[:,0]-x)] |= interpBit + + +@jit +def median_with_mad_clipping(data, mad_multiplier=2.0): + """ + Calculate the median of the input data after applying Median Absolute Deviation (MAD) clipping. + + The MAD clipping method is used to remove outliers from the data. The median of the data is calculated, + and then the MAD is calculated as the median absolute deviation from the median. The data is then clipped + by removing values that are outside the range of median +/- mad_multiplier * MAD. Finally, the median of + the clipped data is returned. + + Parameters: + ----------- + data : `np.array` + Input data array. + mad_multiplier : `float`, optional + Multiplier for the MAD value used for clipping. Default is 2.0. + + Returns: + -------- + median_clipped : `float` + Median value of the clipped data. + + Examples: + --------- + >>> data = [1, 2, 3, 4, 5, 100] + >>> median_with_mad_clipping(data) + 3.5 + """ + median = jnp.median(data) + mad = jnp.median(jnp.abs(data - median)) + clipping_range = mad_multiplier * mad + clipped_data = jnp.clip(data, median - clipping_range, median + clipping_range) + median_clipped = jnp.median(clipped_data) + return median_clipped + + +@jit +def jax_rbf_kernel(x1, x2, sigma, correlation_length): + """ + Computes the radial basis function (RBF) kernel matrix. + + Parameters: + ----------- + x1 : `np.array` + Location of training data point with shape (n_samples, n_features). + x2 : `np.array` + Location of training/test data point with shape (n_samples, n_features). + sigma : `float` + The scale parameter of the kernel. + correlation_length : `float` + The correlation length parameter of the kernel. + + Returns: + -------- + kernel : `np.array` + RBF kernel matrix with shape (n_samples, n_samples). + """ + distance_squared = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1) + kernel = (sigma**2) * jnp.exp(-0.5 * distance_squared / (correlation_length**2)) + return kernel + + +@jit +def jax_get_alpha(y, kernel): + """ + Compute the alpha vector for Gaussian Process interpolation. + + Parameters: + ----------- + y : `np.array` + The target values of the Gaussian Process. + kernel : `np.array` + The kernel matrix of the Gaussian Process. + + Returns: + -------- + alpha : `np.array` + The alpha vector computed using the Cholesky decomposition and solution. + + """ + factor = (jax.scipy.linalg.cholesky(kernel, overwrite_a=True, lower=False), False) + alpha = jax.scipy.linalg.cho_solve(factor, y, overwrite_b=False) + return alpha.reshape((len(alpha), 1)) + + +@jit +def jax_get_gp_predict(kernel_rect, alpha): + """ + Compute the predicted values of gp using the given kernel and alpha (cholesky solution). + + Parameters: + ----------- + kernel_rect : `np.array` + The kernel matrix. + alpha : `np.array` + The alpha vector from Cholesky solution. + + Returns: + -------- + `np.array` + The predicted values of y. + + """ + return jnp.dot(kernel_rect, alpha).T[0] + + +class GaussianProcessJax: + """ + Gaussian Process regression in JAX. + Kernel is assumed to be isotropic RBF kernel, and solved + using exact Cholesky decomposition. + The interpolation solution is obtained by solving the linear system: + y_interp = kernel_rect @ (kernel + y_err**2 * I)^-1 @ y_training. + See the Rasmussen and Williams book for more details. + Each function is decorated with @jit to compile the function. + Exist package like tinygp, that is implemented in jax also. + This class is a custom implementation + of Gaussian Processes, which allows for setting the hyperparameters, + fine-tuning the mean function, + and other specifications. + + Parameters: + ----------- + std : `float`, optional + Standard deviation of the Gaussian Process kernel. Default is 1.0. + correlation_length : `float`, optional + Correlation length of the Gaussian Process kernel. Default is 1.0. + white_noise : `float`, optional + White noise level of the Gaussian Process. Default is 0.0. + mean : `float`, optional + Mean value of the Gaussian Process. Default is 0.0. + + """ + + def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): + self.std = std + self.correlation_length = correlation_length + self.white_noise = white_noise + self.mean = mean + self._alpha = None + + # Looks weird to do that, but this is justified. + # in GP if no noise is provided, even if matrix + # can be inverted, it wont invert because of numerical + # issue (det(K)~0). Add a little bit of noise allow + # to compute a numerical solution in the case of no + # external noise is added. Wont happened on real + # image but help for unit test. + if self.white_noise == 0.0: + self.white_noise = 1e-5 + + def fit(self, x_train, y_train): + y = y_train - self.mean + self._x = x_train + kernel = jax_rbf_kernel(x_train, x_train, self.std, self.correlation_length) + y_err = jnp.ones(len(x_train[:, 0])) * self.white_noise + kernel += jnp.eye(len(y_err)) * (y_err**2) + self._alpha = jax_get_alpha(y, kernel) + + def predict(self, x_predict): + kernel_rect = jax_rbf_kernel( + x_predict, self._x, self.std, self.correlation_length + ) + y_pred = jax_get_gp_predict(kernel_rect, self._alpha) + return y_pred + self.mean + + +class GaussianProcessTreegp: + """ + Gaussian Process Treegp class for Gaussian Process interpolation. + + The basic GP regression, which uses Cholesky decomposition. + + Parameters: + ----------- + std : `float`, optional + Standard deviation of the Gaussian Process kernel. Default is 1.0. + correlation_length : `float`, optional + Correlation length of the Gaussian Process kernel. Default is 1.0. + white_noise : `float`, optional + White noise level of the Gaussian Process. Default is 0.0. + mean : `float`, optional + Mean value of the Gaussian Process. Default is 0.0. + """ + + def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): + self.std = std + self.correlation_length = correlation_length + self.white_noise = white_noise + self.mean = mean + + # Looks like weird to do that, but this is justified. + # in GP if no noise is provided, even if matrix + # can be inverted, it wont invert because of numerical + # issue (det(K)~0). Add a little bit of noise allow + # to compute a numerical solution in the case of no + # external noise is added. Wont happened on real + # image but help for unit test. + if self.white_noise == 0.0: + self.white_noise = 1e-5 + + def fit(self, x_train, y_train): + """ + Fit the Gaussian Process to the given training data. + + Parameters: + ----------- + x_train : `np.array` + Input features for the training data. + y_train : `np.array` + Target values for the training data. + """ + kernel = f"{self.std}**2 * RBF({self.correlation_length})" + self.gp = treegp.GPInterpolation( + kernel=kernel, + optimizer="none", + normalize=False, + white_noise=self.white_noise, + ) + self.gp.initialize(x_train, y_train - self.mean) + self.gp.solve() + + def predict(self, x_predict): + """ + Predict the target values for the given input features. + + Parameters: + ----------- + x_predict : `np.array` + Input features for the prediction. + + Returns: + -------- + y_pred : `np.array` + Predicted target values. + """ + y_pred = self.gp.predict(x_predict) + return y_pred + self.mean + + +class InterpolateOverDefectGaussianProcess: + """ + InterpolateOverDefectGaussianProcess class performs Gaussian Process + (GP) interpolation over defects in an image. + + Parameters: + ----------- + masked_image : `lsst.afw.image.MaskedImage` + The masked image containing the defects to be interpolated. + defects : `list`[`str`], optional + The types of defects to be interpolated. Default is ["SAT"]. + method : `str`, optional + The method to use for GP interpolation. Must be either "jax" or "treegp". Default is "treegp". + fwhm : `float`, optional + The full width at half maximum (FWHM) of the PSF. Default is 5. + bin_spacing : `int`, optional + The spacing between bins for good pixel binning. Default is 10. + threshold_dynamic_binning : `int`, optional + The threshold for dynamic binning. Default is 1000. + threshold_subdivide : `int`, optional + The threshold for sub-dividing the bad pixel array to avoid memory error. Default is 20000. + correlation_length_cut : `int`, optional + The factor by which to dilate the bounding box around defects. Default is 5. + log : `lsst.log.Log`, `logging.Logger` or `None`, optional + Logger object used to write out messages. If `None` a default + logger will be used. + """ + + def __init__( + self, + masked_image, + defects=["SAT"], + method="treegp", + fwhm=5, + bin_image=True, + bin_spacing=10, + threshold_dynamic_binning=1000, + threshold_subdivide=20000, + correlation_length_cut=5, + log=None, + ): + if method == "jax": + self.GaussianProcess = GaussianProcessJax + elif method == "treegp": + self.GaussianProcess = GaussianProcessTreegp + else: + raise ValueError("Invalid method. Must be 'jax' or 'treegp'.") + + self.log = log or logging.getLogger(__name__) + + self.bin_image = bin_image + self.bin_spacing = bin_spacing + self.threshold_subdivide = threshold_subdivide + self.threshold_dynamic_binning = threshold_dynamic_binning + + self.masked_image = masked_image + self.defects = defects + self.correlation_length = fwhm + self.correlation_length_cut = correlation_length_cut + + self.interpBit = self.masked_image.mask.getPlaneBitMask("INTRP") + + def run(self): + """ + Interpolate over the defects in the image. + + Change self.masked_image . + """ + if self.defects == [] or self.defects is None: + self.log.info("No defects found. No interpolation performed.") + else: + mask = self.masked_image.getMask() + bad_pixel_mask = mask.getPlaneBitMask(self.defects) + bad_mask_span_set = SpanSet.fromMask(mask, bad_pixel_mask).split() + + bbox = self.masked_image.getBBox() + global_xmin, global_xmax = bbox.minX, bbox.maxX + global_ymin, global_ymax = bbox.minY, bbox.maxY + + for spanset in bad_mask_span_set: + bbox = spanset.getBBox() + # Dilate the bbox to make sure we have enough good pixels around the defect + # For now, we dilate by 5 times the correlation length + # For GP with the isotropic kernel, points at the default value of + # correlation_length_cut=5 have negligible effect on the prediction. + bbox = bbox.dilatedBy( + int(self.correlation_length * self.correlation_length_cut) + ) # need integer as input. + xmin, xmax = max([global_xmin, bbox.minX]), min(global_xmax, bbox.maxX) + ymin, ymax = max([global_ymin, bbox.minY]), min(global_ymax, bbox.maxY) + localBox = Box2I(Point2I(xmin, ymin), Point2I(xmax - xmin, ymax - ymin)) + masked_sub_image = self.masked_image[localBox] + + masked_sub_image = self.interpolate_masked_sub_image(masked_sub_image) + self.masked_image[localBox] = masked_sub_image + + def _good_pixel_binning(self, pixels): + """ + Performs pixel binning using treegp.meanify + + Parameters: + ----------- + pixels : `np.array` + The array of pixels. + + Returns: + -------- + `np.array` + The binned array of pixels. + """ + + n_pixels = len(pixels[:, 0]) + dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning)) + if n_pixels / self.bin_spacing**2 < n_pixels / dynamic_binning**2: + bin_spacing = self.bin_spacing + else: + bin_spacing = dynamic_binning + binning = treegp.meanify(bin_spacing=bin_spacing, statistics="mean") + binning.add_field( + pixels[:, :2], + pixels[:, 2:].T, + ) + binning.meanify() + return np.array( + [binning.coords0[:, 0], binning.coords0[:, 1], binning.params0] + ).T + + def interpolate_masked_sub_image(self, masked_sub_image): + """ + Interpolate the masked sub-image. + + Parameters: + ----------- + masked_sub_image : `lsst.afw.image.MaskedImage` + The sub-masked image to be interpolated. + + Returns: + -------- + `lsst.afw.image.MaskedImage` + The interpolated sub-masked image. + """ + + cut = int( + self.correlation_length * self.correlation_length_cut + ) # need integer as input. + bad_pixel, good_pixel = ctUtils.findGoodPixelsAroundBadPixels( + masked_sub_image, self.defects, buffer=cut + ) + # Do nothing if bad pixel is None. + if bad_pixel.size == 0 or good_pixel.size == 0: + self.log.info("No bad or good pixels found. No interpolation performed.") + return masked_sub_image + # Do GP interpolation if bad pixel found. + else: + # gp interpolation + sub_image_array = masked_sub_image.getVariance().array + white_noise = np.sqrt( + np.mean(sub_image_array[np.isfinite(sub_image_array)]) + ) + kernel_amplitude = np.max(good_pixel[:, 2:]) + if not np.isfinite(kernel_amplitude): + filter_finite = np.isfinite(good_pixel[:, 2:]).T[0] + good_pixel = good_pixel[filter_finite] + if good_pixel.size == 0: + self.log.info( + "No bad or good pixels found. No interpolation performed." + ) + return masked_sub_image + # kernel amplitude might be better described by maximum value of good pixel given + # the data and not really a random gaussian field. + kernel_amplitude = np.max(good_pixel[:, 2:]) + + if self.bin_image: + try: + good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel)) + except Exception: + self.log.info( + "Binning failed, use original good pixel array in interpolation." + ) + + # put this after binning as computing median is O(n*log(n)) + clipped_median = median_with_mad_clipping(good_pixel[:, 2:]) + + gp = self.GaussianProcess( + std=np.sqrt(kernel_amplitude), + correlation_length=self.correlation_length, + white_noise=white_noise, + mean=clipped_median, + ) + gp.fit(good_pixel[:, :2], np.squeeze(good_pixel[:, 2:])) + if bad_pixel.size < self.threshold_subdivide: + gp_predict = gp.predict(bad_pixel[:, :2]) + bad_pixel[:, 2:] = gp_predict.reshape(np.shape(bad_pixel[:, 2:])) + else: + self.log.info("sub-divide bad pixel array to avoid memory error.") + for i in range(0, len(bad_pixel), self.threshold_subdivide): + end = min(i + self.threshold_subdivide, len(bad_pixel)) + gp_predict = gp.predict(bad_pixel[i:end, :2]) + bad_pixel[i:end, 2:] = gp_predict.reshape( + np.shape(bad_pixel[i:end, 2:]) + ) + + # Update values + ctUtils.updateImageFromArray(masked_sub_image.image, bad_pixel) + updateMaskFromArray(masked_sub_image.mask, bad_pixel, self.interpBit) + return masked_sub_image diff --git a/python/lsst/meas/algorithms/interp.cc b/python/lsst/meas/algorithms/interp.cc index 885226d40..142b41976 100644 --- a/python/lsst/meas/algorithms/interp.cc +++ b/python/lsst/meas/algorithms/interp.cc @@ -37,7 +37,7 @@ namespace { template void declareInterpolateOverDefects(py::module& mod) { - mod.def("interpolateOverDefects", + mod.def("legacyInterpolateOverDefects", interpolateOverDefects< afw::image::MaskedImage>, "image"_a, "psf"_a, "badList"_a, "fallBackValue"_a = 0.0, "useFallbackValueAtEdge"_a = false); diff --git a/python/lsst/meas/algorithms/interp.py b/python/lsst/meas/algorithms/interp.py new file mode 100644 index 000000000..554a3cf92 --- /dev/null +++ b/python/lsst/meas/algorithms/interp.py @@ -0,0 +1,25 @@ +from . import legacyInterpolateOverDefects +from . import InterpolateOverDefectGaussianProcess + +__all__ = ["interpolateOverDefects"] + + +def interpolateOverDefects( + image, + psf, + badList, + fallbackValue=0.0, + useFallbackValueAtEdge=False, + fwhm=1.0, + useLegacyInterp=True, + maskNameList=None, + **kwargs +): + if useLegacyInterp: + return legacyInterpolateOverDefects( + image, psf, badList, fallbackValue, useFallbackValueAtEdge + ) + else: + gp = InterpolateOverDefectGaussianProcess(image, fwhm=fwhm, + defects=maskNameList, **kwargs) + return gp.run() diff --git a/python/lsst/meas/algorithms/reinterpolate_pixels.py b/python/lsst/meas/algorithms/reinterpolate_pixels.py index b04c2ccd3..a8719f94d 100644 --- a/python/lsst/meas/algorithms/reinterpolate_pixels.py +++ b/python/lsst/meas/algorithms/reinterpolate_pixels.py @@ -27,7 +27,7 @@ import lsst.afw.math as afwMath import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase -from lsst.meas.algorithms import Defect, interpolateOverDefects +from . import Defect, interpolateOverDefects class ReinterpolatePixelsConfig(pexConfig.Config): diff --git a/tests/test_gp_interp.py b/tests/test_gp_interp.py new file mode 100644 index 000000000..19f5c716b --- /dev/null +++ b/tests/test_gp_interp.py @@ -0,0 +1,193 @@ +# This file is part of meas_algorithms. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +import unittest + +import numpy as np + +import lsst.utils.tests +import lsst.geom +import lsst.afw.image as afwImage +from lsst.meas.algorithms import ( + InterpolateOverDefectGaussianProcess, + GaussianProcessTreegp, +) + + +def rbf_kernel(x1, x2, sigma, correlation_length): + """ + Computes the radial basis function (RBF) kernel matrix. + + Parameters: + ----------- + x1 : `np.array` + Location of training data point with shape (n_samples, n_features). + x2 : `np.array` + Location of training/test data point with shape (n_samples, n_features). + sigma : `float` + The scale parameter of the kernel. + correlation_length : `float` + The correlation length parameter of the kernel. + + Returns: + -------- + kernel : `np.array` + RBF kernel matrix with shape (n_samples, n_samples). + """ + distance_squared = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1) + kernel = (sigma**2) * np.exp(-0.5 * distance_squared / (correlation_length**2)) + return kernel + + +class InterpolateOverDefectGaussianProcessTestCase(lsst.utils.tests.TestCase): + """Test InterpolateOverDefectGaussianProcess.""" + + def setUp(self): + super().setUp() + + npoints = 1000 + self.std = 100 + self.correlation_length = 10.0 + self.white_noise = 1e-5 + + x1 = np.random.uniform(0, 99, npoints) + x2 = np.random.uniform(0, 120, npoints) + coord1 = np.array([x1, x2]).T + + kernel = rbf_kernel(coord1, coord1, self.std, self.correlation_length) + kernel += np.eye(npoints) * self.white_noise**2 + + # Data augmentation. Create a gaussian random field + # on a 100 * 100 is to slow. So generate 1e3 points + # and then interpolate it with a GP to do data augmentation. + + np.random.seed(42) + z1 = np.random.multivariate_normal(np.zeros(npoints), kernel) + + x1 = np.linspace(0, 99, 100) + x2 = np.linspace(0, 120, 121) + x2, x1 = np.meshgrid(x2, x1) + coord2 = np.array([x1.reshape(-1), x2.reshape(-1)]).T + + tgp = GaussianProcessTreegp( + std=self.std, + correlation_length=self.correlation_length, + white_noise=self.white_noise, + mean=0.0, + ) + tgp.fit(coord1, z1) + z2 = tgp.predict(coord2) + z2 = z2.reshape(100, 121) + + self.maskedimage = afwImage.MaskedImageF(100, 121) + for x in range(100): + for y in range(121): + self.maskedimage[x, y] = (z2[x, y], 0, 1.0) + + # Clone the maskedimage so we can compare it after running the task. + self.reference = self.maskedimage.clone() + + # Set some central pixels as SAT + sliceX, sliceY = slice(30, 35), slice(40, 45) + self.maskedimage.mask[sliceX, sliceY] = afwImage.Mask.getPlaneBitMask("SAT") + self.maskedimage.image[sliceX, sliceY] = np.nan + # Put nans here to make sure interp is done ok + + # Set an entire column as BAD + self.maskedimage.mask[54:55, :] = afwImage.Mask.getPlaneBitMask("BAD") + self.maskedimage.image[54:55, :] = np.nan + + # Set an entire row as BAD + self.maskedimage.mask[:, 110:111] = afwImage.Mask.getPlaneBitMask("BAD") + self.maskedimage.image[:, 110:111] = np.nan + + # Set a diagonal set of pixels as CR + for i in range(74, 78): + self.maskedimage.mask[i, i] = afwImage.Mask.getPlaneBitMask("CR") + self.maskedimage.image[i, i] = np.nan + + # Set one of the edges as EDGE + self.maskedimage.mask[0:1, :] = afwImage.Mask.getPlaneBitMask("EDGE") + self.maskedimage.image[0:1, :] = np.nan + + # Set a smaller streak at the edge + self.maskedimage.mask[25:28, 0:1] = afwImage.Mask.getPlaneBitMask("EDGE") + self.maskedimage.image[25:28, 0:1] = np.nan + + # Update the reference image's mask alone, so we can compare them after + # running the task. + self.reference.mask.array[:, :] = self.maskedimage.mask.array + + # Create a noise image + # self.noise = self.maskedimage.clone() + # np.random.seed(12345) + # self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape) + + @lsst.utils.tests.methodParameters(method=("jax")) + def test_interpolation(self, method: str): + """Test that the interpolation is done correctly. + + Parameters + ---------- + method : `str` + Code used to solve gaussian process. + """ + + gp = InterpolateOverDefectGaussianProcess( + self.maskedimage, + defects=["BAD", "SAT", "CR", "EDGE"], + method=method, + fwhm=self.correlation_length, + bin_image=False, + bin_spacing=30, + threshold_dynamic_binning=1000, + threshold_subdivide=20000, + correlation_length_cut=5, + log=None, + ) + + gp.run() + + # Assert that the mask and the variance planes remain unchanged. + self.assertImagesEqual(self.maskedimage.variance, self.reference.variance) + + # Check that interpolated pixels are close to the reference (original), + # and that none of them is still NaN. + self.assertTrue(np.isfinite(self.maskedimage.image.array).all()) + self.assertImagesAlmostEqual( + self.maskedimage.image[1:, :], + self.reference.image[1:, :], + atol=2, + ) + + +def setup_module(module): + lsst.utils.tests.init() + + +class MemoryTestCase(lsst.utils.tests.MemoryTestCase): + pass + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main()