Skip to content

Commit

Permalink
black / flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
PFLeget committed Jun 19, 2024
1 parent a4f10d8 commit e451ab4
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 48 deletions.
113 changes: 71 additions & 42 deletions python/lsst/meas/algorithms/gp_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def updateMaskFromArray(mask, bad_pixel, interpBit):
y = int(row[1] - y0)
mask.array[y, x] |= interpBit


@jit
def median_with_mad_clipping(data, mad_multiplier=2.0):
"""
Expand Down Expand Up @@ -76,6 +77,7 @@ def median_with_mad_clipping(data, mad_multiplier=2.0):
median_clipped = jnp.median(clipped_data)
return median_clipped


# Below are the jax functions for Gaussian Process regression.
# Kernel is assumed to be isotropic RBF kernel, and solved
# using exact Cholesky decomposition.
Expand All @@ -88,6 +90,7 @@ def median_with_mad_clipping(data, mad_multiplier=2.0):
# of custom things (setting my own hyperparameters, fine tune mean function,
# dynamic binning, ...).


@jit
def jax_pdist_squared(x):
"""
Expand Down Expand Up @@ -118,6 +121,7 @@ def jax_pdist_squared(x):
"""
return jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)


@jit
def jax_cdist_squared(xa, xb):
"""
Expand Down Expand Up @@ -152,6 +156,7 @@ def jax_cdist_squared(xa, xb):
"""
return jnp.sum((xa[:, None, :] - xb[None, :, :]) ** 2, axis=-1)


@jit
def jax_rbf_kernel(x, sigma, correlation_length, y_err):
"""
Expand All @@ -175,10 +180,11 @@ def jax_rbf_kernel(x, sigma, correlation_length, y_err):
"""
distance_squared = jax_pdist_squared(x)
kernel = (sigma**2) * jnp.exp(-0.5 * distance_squared / (correlation_length**2))
y_err = jnp.ones(len(x[:,0])) * y_err
y_err = jnp.ones(len(x[:, 0])) * y_err
kernel += jnp.eye(len(y_err)) * (y_err**2)
return kernel


@jit
def jax_rbf_kernel_rect(x1, x2, sigma, correlation_length):
"""
Expand Down Expand Up @@ -227,6 +233,7 @@ def jax_get_alpha(y, kernel):
alpha = jax.scipy.linalg.cho_solve(factor, y, overwrite_b=False)
return alpha.reshape((len(alpha), 1))


@jit
def jax_get_y_predict(kernel_rect, alpha):
"""
Expand All @@ -250,27 +257,24 @@ def jax_get_y_predict(kernel_rect, alpha):

class GaussianProcessJax:
def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):

self.std = std
self.l = correlation_length
self.correlation_lenght = correlation_length
self.white_noise = white_noise
self.mean = mean
self._alpha = None

def fit(self, x_good, y_good):

y = y_good - self.mean
self._x = x_good
kernel = jax_rbf_kernel(x_good, self.std, self.l, self.white_noise)
kernel = jax_rbf_kernel(x_good, self.std, self.correlation_lenght, self.white_noise)
self._alpha = jax_get_alpha(y, kernel)


def predict(self, x_bad):

kernel_rect = jax_rbf_kernel_rect(x_bad, self._x, self.std, self.l)
kernel_rect = jax_rbf_kernel_rect(x_bad, self._x, self.std, self.correlation_lenght)
y_pred = jax_get_y_predict(kernel_rect, self._alpha)
return y_pred + self.mean


# Vanilla Gaussian Process regression using treegp package
# There is no fancy O(N*log(N)) solver here, just the basic GP regression (Cholesky).
class GaussianProcessTreegp:
Expand All @@ -292,7 +296,7 @@ class GaussianProcessTreegp:
--------
fit(x_good, y_good):
Fit the Gaussian Process to the given training data.
Parameters:
-----------
x_good : array-like
Expand All @@ -302,7 +306,7 @@ class GaussianProcessTreegp:
predict(x_bad):
Predict the target values for the given input features.
Parameters:
-----------
x_bad : array-like
Expand All @@ -317,7 +321,7 @@ class GaussianProcessTreegp:

def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
self.std = std
self.l = correlation_length
self.correlation_length = correlation_length
self.white_noise = white_noise
self.mean = mean

Expand All @@ -332,7 +336,7 @@ def fit(self, x_good, y_good):
y_good : array-like
Target values for the training data.
"""
kernel = f"{self.std}**2 * RBF({self.l})"
kernel = f"{self.std}**2 * RBF({self.correlation_length})"
self.gp = treegp.GPInterpolation(
kernel=kernel,
optimizer="none",
Expand Down Expand Up @@ -362,7 +366,8 @@ def predict(self, x_bad):

class InterpolateOverDefectGaussianProcess:
"""
InterpolateOverDefectGaussianProcess class performs Gaussian Process (GP) interpolation over defects in an image.
InterpolateOverDefectGaussianProcess class performs Gaussian Process
(GP) interpolation over defects in an image.
Parameters:
-----------
Expand Down Expand Up @@ -429,7 +434,6 @@ def __init__(
threshold_subdivide=20000,
correlation_length_cut=5,
):

if method == "jax":
self.GaussianProcess = GaussianProcessJax
elif method == "treegp":
Expand Down Expand Up @@ -467,18 +471,18 @@ def interpolate_over_defects(self):
# For now, we dilate by 5 times the correlation length
# For GP with isotropic kernel, points at 5 correlation lengths away have negligible
# effect on the prediction.
bbox = bbox.dilatedBy(int(self.correlation_length * self.correlation_length_cut)) # need integer as input.
bbox = bbox.dilatedBy(
int(self.correlation_length * self.correlation_length_cut)
) # need integer as input.
xmin, xmax = max([glob_xmin, bbox.minX]), min(glob_xmax, bbox.maxX)
ymin, ymax = max([glob_ymin, bbox.minY]), min(glob_ymax, bbox.maxY)
localBox = Box2I(Point2I(xmin, ymin), Extent2I(xmax - xmin, ymax - ymin))
try:
sub_masked_image = self.maskedImage[localBox]
except:
except IndexError:
raise ValueError("Sub-masked image not found.")

sub_masked_image = self.interpolate_sub_masked_image(
sub_masked_image
)
sub_masked_image = self.interpolate_sub_masked_image(sub_masked_image)
self.maskedImage[localBox] = sub_masked_image

def _good_pixel_binning(self, good_pixel):
Expand All @@ -496,17 +500,21 @@ def _good_pixel_binning(self, good_pixel):
The binned array of good pixels.
"""

n_pixels = len(good_pixel[:,0])
n_pixels = len(good_pixel[:, 0])
dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning))
if n_pixels/self.bin_spacing**2 < n_pixels/dynamic_binning**2:
if n_pixels / self.bin_spacing**2 < n_pixels / dynamic_binning**2:
bin_spacing = self.bin_spacing
else:
bin_spacing = dynamic_binning
binning = treegp.meanify(bin_spacing=bin_spacing, statistics='mean')
binning.add_field(good_pixel[:, :2], good_pixel[:, 2:].T,)
binning = treegp.meanify(bin_spacing=bin_spacing, statistics="mean")
binning.add_field(
good_pixel[:, :2],
good_pixel[:, 2:].T,
)
binning.meanify()
return np.array([binning.coords0[:, 0], binning.coords0[:, 1], binning.params0]).T

return np.array(
[binning.coords0[:, 0], binning.coords0[:, 1], binning.params0]
).T

def interpolate_sub_masked_image(self, sub_masked_image):
"""
Expand All @@ -523,13 +531,15 @@ def interpolate_sub_masked_image(self, sub_masked_image):
The interpolated sub-masked image.
"""

cut = int(self.correlation_length * self.correlation_length_cut) # need integer as input.
cut = int(
self.correlation_length * self.correlation_length_cut
) # need integer as input.
bad_pixel, good_pixel = ctUtils.findGoodPixelsAroundBadPixels(
sub_masked_image, self.defects, buffer=cut
)
# Do nothing if bad pixel is None.
if bad_pixel.size == 0 or good_pixel.size == 0:
warnings.warn('No bad or good pixels found. No interpolation performed.')
warnings.warn("No bad or good pixels found. No interpolation performed.")
return sub_masked_image
# Do GP interpolation if bad pixel found.
else:
Expand All @@ -543,7 +553,9 @@ def interpolate_sub_masked_image(self, sub_masked_image):
filter_finite = np.isfinite(good_pixel[:, 2:]).T[0]
good_pixel = good_pixel[filter_finite]
if good_pixel.size == 0:
warnings.warn('No bad or good pixels found. No interpolation performed.')
warnings.warn(
"No bad or good pixels found. No interpolation performed."
)
return sub_masked_image
# kernel amplitude might be better described by maximum value of good pixel given
# the data and not really a random gaussian field.
Expand All @@ -552,8 +564,10 @@ def interpolate_sub_masked_image(self, sub_masked_image):
kernel_amplitude = np.max(good_pixel[:, 2:])
try:
good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel))
except:
warnings.warn('Binning failed, use original good pixel array in interpolate over.')
except Exception:
warnings.warn(
"Binning failed, use original good pixel array in interpolate over."
)

# put this after binning as comupting median is O(n*log(n))
mean = median_with_mad_clipping(good_pixel[:, 2:])
Expand All @@ -569,19 +583,29 @@ def interpolate_sub_masked_image(self, sub_masked_image):
gp_predict = gp.predict(bad_pixel[:, :2])
bad_pixel[:, 2:] = gp_predict.reshape(np.shape(bad_pixel[:, 2:]))
else:
warnings.warn('sub-divide bad pixel array to avoid memory error.')
warnings.warn("sub-divide bad pixel array to avoid memory error.")
for i in range(0, len(bad_pixel), self.threshold_subdivide):
end = min(i + self.threshold_subdivide, len(bad_pixel))
gp_predict = gp.predict(bad_pixel[i : end, :2])
bad_pixel[i : end, 2:] = gp_predict.reshape(np.shape(bad_pixel[i : end, 2:]))
end = min(i + self.threshold_subdivide, len(bad_pixel))
gp_predict = gp.predict(bad_pixel[i:end, :2])
bad_pixel[i:end, 2:] = gp_predict.reshape(
np.shape(bad_pixel[i:end, 2:])
)

# update_value
ctUtils.updateImageFromArray(sub_masked_image.image, bad_pixel)
updateMaskFromArray(sub_masked_image.mask, bad_pixel, self.interpBit)
return sub_masked_image

def interpolateOverDefectsGP(image, fwhm, badList, method="treegp", bin_spacing=25,
threshold_dynamic_binning=1000, threshold_subdivide=20000):


def interpolateOverDefectsGP(
image,
fwhm,
badList,
method="treegp",
bin_spacing=25,
threshold_dynamic_binning=1000,
threshold_subdivide=20000,
):
"""
Interpolates over defects in an image using Gaussian Process interpolation.
Expand Down Expand Up @@ -619,10 +643,15 @@ def interpolateOverDefectsGP(image, fwhm, badList, method="treegp", bin_spacing=
"""
if badList == [] or badList is None:
warnings.warn('WARNING: no defects found. No interpolation performed.')
warnings.warn("WARNING: no defects found. No interpolation performed.")
return
gp = InterpolateOverDefectGaussianProcess(image, defects=badList, method=method,
fwhm=fwhm, bin_spacing=bin_spacing,
threshold_dynamic_binning=threshold_dynamic_binning,
threshold_subdivide=threshold_subdivide)
gp = InterpolateOverDefectGaussianProcess(
image,
defects=badList,
method=method,
fwhm=fwhm,
bin_spacing=bin_spacing,
threshold_dynamic_binning=threshold_dynamic_binning,
threshold_subdivide=threshold_subdivide,
)
gp.interpolate_over_defects()
22 changes: 16 additions & 6 deletions python/lsst/meas/algorithms/interp.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from . import interpolateOverDefectsOld
from . import interpolateOverDefectsGP

__all__ = ['interpolateOverDefects']
__all__ = ["interpolateOverDefects"]


def interpolateOverDefects(image, psf, badList, fallbackValue=0.0, fwhm=1.0,
useFallbackValueAtEdge=False, useLegacyInterp=False,
maskNameList=None, **kwargs):
def interpolateOverDefects(
image,
psf,
badList,
fallbackValue=0.0,
fwhm=1.0,
useFallbackValueAtEdge=False,
useLegacyInterp=False,
maskNameList=None,
**kwargs
):
if useLegacyInterp:
return interpolateOverDefectsOld(image, psf, badList, fallbackValue, useFallbackValueAtEdge)
return interpolateOverDefectsOld(
image, psf, badList, fallbackValue, useFallbackValueAtEdge
)
else:
return interpolateOverDefectsGP(image, fwhm, maskNameList, **kwargs)
return interpolateOverDefectsGP(image, fwhm, maskNameList, **kwargs)

0 comments on commit e451ab4

Please sign in to comment.