Skip to content

Commit

Permalink
accelerate comp median and add dynamic binning | change api
Browse files Browse the repository at this point in the history
  • Loading branch information
PFLeget committed Jun 4, 2024
1 parent 91b6d1a commit 0698454
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
27 changes: 18 additions & 9 deletions python/lsst/meas/algorithms/gp_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def __init__(
method="treegp",
fwhm=5,
bin_spacing=10,
threshold_dynamic_binning=1000,
threshold_subdivide=20000,
):
"""
Expand All @@ -250,6 +251,7 @@ def __init__(

self.bin_spacing = bin_spacing
self.threshold_subdivide = threshold_subdivide
self.threshold_dynamic_binning = threshold_dynamic_binning

self.maskedImage = maskedImage
self.defects = defects
Expand Down Expand Up @@ -286,15 +288,12 @@ def interpolate_over_defects(self):
sub_masked_image = self.maskedImage[localBox]
except:
raise ValueError("Sub-masked image not found.")
# try:

sub_masked_image = self.interpolate_sub_masked_image(
sub_masked_image
)
# except:
# raise ValueError("Interpolation failed.")
self.maskedImage[localBox] = sub_masked_image


def _good_pixel_binning(self, good_pixel):
"""
Performs binning of good pixel data.
Expand All @@ -306,11 +305,18 @@ def _good_pixel_binning(self, good_pixel):
- numpy.ndarray: An array containing the binned data.
"""
binning = treegp.meanify(bin_spacing=self.bin_spacing, statistics='mean')
n_pixels = len(good_pixel[:,0])
dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning))
if n_pixels/self.bin_spacing**2 < n_pixels/dynamic_binning**2:
bin_spacing = self.bin_spacing
else:
bin_spacing = dynamic_binning
binning = treegp.meanify(bin_spacing=bin_spacing, statistics='mean')
binning.add_field(good_pixel[:, :2], good_pixel[:, 2:].T,)
binning.meanify()
return np.array([binning.coords0[:, 0], binning.coords0[:, 1], binning.params0]).T


def interpolate_sub_masked_image(self, sub_masked_image):
"""
Interpolates over defects in a sub-masked image.
Expand All @@ -333,7 +339,6 @@ def interpolate_sub_masked_image(self, sub_masked_image):
# Do GP interpolation if bad pixel found.
else:
# gp interpolation
mean = median_with_mad_clipping(good_pixel[:, 2:]) # np.mean(good_pixel[:, 2:])
sub_image_array = sub_masked_image.getVariance().array
white_noise = np.sqrt(
np.mean(sub_image_array[np.isfinite(sub_image_array)])
Expand All @@ -346,12 +351,14 @@ def interpolate_sub_masked_image(self, sub_masked_image):
warnings.warn('No bad or good pixels found. No interpolation performed.')
return sub_masked_image
kernel_amplitude = np.std(good_pixel[:, 2:])
mean = np.mean(good_pixel[:, 2:])
try:
good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel))
except:
warnings.warn('Binning failed, use original good pixel array in interpolate over.')

# put this after binning as comupting median is O(n*log(n))
mean = median_with_mad_clipping(good_pixel[:, 2:])

gp = self.GaussianProcess(
std=np.sqrt(kernel_amplitude),
correlation_length=self.correlation_length,
Expand All @@ -374,7 +381,8 @@ def interpolate_sub_masked_image(self, sub_masked_image):
updateMaskFromArray(sub_masked_image.mask, bad_pixel, self.interpBit)
return sub_masked_image

def interpolateOverDefectsGP(image, fwhm, badList, method="treegp", bin_spacing=25, threshold_subdivide=20000):
def interpolateOverDefectsGP(image, fwhm, badList, method="treegp", bin_spacing=25,
threshold_dynamic_binning=1000, threshold_subdivide=20000):
"""
Interpolates over defects in an image using Gaussian Process interpolation.
Expand Down Expand Up @@ -406,6 +414,7 @@ def interpolateOverDefectsGP(image, fwhm, badList, method="treegp", bin_spacing=
warnings.warn('WARNING: no defects found. No interpolation performed.')
return
gp = InterpolateOverDefectGaussianProcess(image, defects=badList, method=method,
fwhm=fwhm, bin_spacing=bin_spacing,
fwhm=fwhm, bin_spacing=bin_spacing,
threshold_dynamic_binning=threshold_dynamic_binning,
threshold_subdivide=threshold_subdivide)
gp.interpolate_over_defects()
8 changes: 4 additions & 4 deletions python/lsst/meas/algorithms/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@


def interpolateOverDefects(image, psf, badList, fallbackValue=0.0, fwhm=1.0,
useFallbackValueAtEdge=False, useGaussianProcess=False,
useFallbackValueAtEdge=False, useLegacyInterp=False,
maskNameList=None, **kwargs):
if useGaussianProcess:
return interpolateOverDefectsGP(image, fwhm, maskNameList, **kwargs)
else:
if useLegacyInterp:
return interpolateOverDefectsOld(image, psf, badList, fallbackValue, useFallbackValueAtEdge)
else:
return interpolateOverDefectsGP(image, fwhm, maskNameList, **kwargs)

0 comments on commit 0698454

Please sign in to comment.