Skip to content

Commit

Permalink
Merge branch 'tickets/DM-45361'
Browse files Browse the repository at this point in the history
  • Loading branch information
mrawls committed Nov 21, 2024
2 parents f3fb38a + 2f3b1b4 commit fec2e1e
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
48 changes: 43 additions & 5 deletions python/lsst/ip/diffim/detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import numpy as np

import lsst.afw.detection as afwDetection
import lsst.afw.image as afwImage
import lsst.afw.math as afwMath
import lsst.afw.table as afwTable
import lsst.daf.base as dafBase
import lsst.geom
Expand Down Expand Up @@ -169,6 +171,12 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
doc="Subtask for masking streaks. Only used if doMaskStreaks is True. "
"Adds a mask plane to an exposure, with the mask plane name set by streakMaskName.",
)
streakBinFactor = pexConfig.Field(
dtype=int,
default=4,
doc="Bin scale factor to use when rerunning detection for masking streaks. "
"Only used if doMaskStreaks is True.",
)
writeStreakInfo = pexConfig.Field(
dtype=bool,
default=False,
Expand Down Expand Up @@ -444,7 +452,7 @@ def processResults(self, science, matchedTemplate, difference, sources, idFactor
self.metadata["nMergedDiaSources"] = len(initialDiaSources)

if self.config.doMaskStreaks:
streakInfo = self._runStreakMasking(difference.maskedImage)
streakInfo = self._runStreakMasking(difference)

if self.config.doSkySources:
self.addSkySources(initialDiaSources, difference.mask, difference.info.id)
Expand Down Expand Up @@ -657,14 +665,20 @@ def calculateMetrics(self, difference):
self.metadata["%s_mask_fraction"%maskPlane.lower()] = -1
self.log.info("Unable to calculate metrics for mask plane %s: not in image"%maskPlane)

def _runStreakMasking(self, maskedImage):
def _runStreakMasking(self, difference):
"""Do streak masking and optionally save the resulting streak
fit parameters in a catalog.
Only returns non-empty streakInfo if self.config.writeStreakInfo
is set. The difference image is binned by self.config.streakBinFactor
(and detection is run a second time) so that regions with lower
surface brightness streaks are more likely to fall above the
detection threshold.
Parameters
----------
maskedImage: `lsst.afw.image.maskedImage`
The image in which to search for streaks. Must have a detection
difference: `lsst.afw.image.Exposure`
The exposure in which to search for streaks. Must have a detection
mask.
Returns
Expand All @@ -681,7 +695,31 @@ def _runStreakMasking(self, maskedImage):
``modelMaximum`` : `np.ndarray`
Peak value of the fit line profile.
"""
streaks = self.maskStreaks.run(maskedImage)
maskedImage = difference.maskedImage
# Bin the diffim to enhance low surface brightness streaks
binnedMaskedImage = afwMath.binImage(maskedImage,
self.config.streakBinFactor,
self.config.streakBinFactor)
binnedExposure = afwImage.ExposureF(binnedMaskedImage.getBBox())
binnedExposure.setMaskedImage(binnedMaskedImage)
binnedExposure.setPsf(difference.psf) # exposure must have a PSF
# Rerun detection to set the DETECTED mask plane on binnedExposure
_table = afwTable.SourceTable.make(self.schema)
self.detection.run(table=_table, exposure=binnedExposure, doSmooth=True)
binnedDetectedMaskPlane = binnedExposure.mask.array & binnedExposure.mask.getPlaneBitMask('DETECTED')
rescaledDetectedMaskPlane = binnedDetectedMaskPlane.repeat(self.config.streakBinFactor,
axis=0).repeat(self.config.streakBinFactor,
axis=1)
# Create new version of a diffim with DETECTED based on binnedExposure
streakMaskedImage = maskedImage.clone()
ysize, xsize = rescaledDetectedMaskPlane.shape
streakMaskedImage.mask.array[:ysize, :xsize] |= rescaledDetectedMaskPlane
# Detect streaks on this new version of the diffim
streaks = self.maskStreaks.run(streakMaskedImage)
streakMaskPlane = streakMaskedImage.mask.array & streakMaskedImage.mask.getPlaneBitMask('STREAK')
# Apply the new STREAK mask to the original diffim
maskedImage.mask.array |= streakMaskPlane

if self.config.writeStreakInfo:
rhos = np.array([line.rho for line in streaks.lines])
thetas = np.array([line.theta for line in streaks.lines])
Expand Down
34 changes: 33 additions & 1 deletion tests/test_detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np
import unittest

import lsst.afw.geom as afwGeom
import lsst.afw.image as afwImage
import lsst.afw.math as afwMath
import lsst.geom
Expand Down Expand Up @@ -570,7 +571,9 @@ def test_mask_streaks(self):
noiseLevel = 1.
staticSeed = 1
fluxLevel = 500
kwargs = {"seed": staticSeed, "psfSize": 2.4, "fluxLevel": fluxLevel}
xSize = 400
ySize = 400
kwargs = {"seed": staticSeed, "psfSize": 2.4, "fluxLevel": fluxLevel, "xSize": xSize, "ySize": ySize}
science, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6, **kwargs)
matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs)

Expand All @@ -594,6 +597,35 @@ def test_mask_streaks(self):
streakMaskSet = (outMask & streakMask) > 0
self.assertTrue(np.all(streakMaskSet[20:23, 40:200]))

# Add a more trapezoid shaped streak across an image that is
# fainter and check that it is also detected
bbox = science.getBBox()
difference = science.clone()
difference.maskedImage -= matchedTemplate.maskedImage
width = 100
amplitude = 5
x0 = -100 + bbox.getBeginX()
y0 = -100 + bbox.getBeginY()
x1 = xSize + x0 + 100
y1 = ySize/2
corner_coords = [lsst.geom.Point2D(x0, y0),
lsst.geom.Point2D(x0, y0 + width),
lsst.geom.Point2D(x1, y1),
lsst.geom.Point2D(x1 + width, y1)]
streak_trapezoid = afwGeom.Polygon(corner_coords)
streak_image = streak_trapezoid.createImage(bbox)
streak_image *= amplitude
difference.image.array += streak_image.array
output = detectionTask.run(science, matchedTemplate, difference)
outMask = output.subtractedMeasuredExposure.mask.array
streakMask = output.subtractedMeasuredExposure.mask.getPlaneBitMask("STREAK")
streakMaskSet = (outMask & streakMask) > 0
# Check that pixel values in streak_image equal the streak amplitude
streak_check = np.where(streak_image.array == amplitude)
self.assertTrue(np.all(streakMaskSet[streak_check]))
# Check that the entire image was not masked STREAK
self.assertFalse(np.all(streakMaskSet))


class DetectAndMeasureScoreTest(DetectAndMeasureTestBase, lsst.utils.tests.TestCase):
detectionTask = detectAndMeasure.DetectAndMeasureScoreTask
Expand Down

0 comments on commit fec2e1e

Please sign in to comment.