Skip to content

Commit

Permalink
Merge pull request #233 from lsst/tickets/DM-46050
Browse files Browse the repository at this point in the history
DM-46050: Add option to use image average PSF for alert cutouts.
  • Loading branch information
isullivan committed Sep 11, 2024
2 parents 37d7fc5 + 2eaa4b4 commit c3fbf1a
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 25 deletions.
4 changes: 4 additions & 0 deletions python/lsst/ap/association/diaForcedSource.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def _calibrate_and_merge(self,
"slot_PsfFlux")

output_catalog = diff_sources.asAstropy().to_pandas()
# afwTable source catalogs store coordinates as radians, but the
# output must be in degrees
output_catalog.loc[:, "ra"] = np.rad2deg(output_catalog.loc[:, "ra"])
output_catalog.loc[:, "dec"] = np.rad2deg(output_catalog.loc[:, "dec"])
output_catalog.rename(columns={"id": "diaForcedSourceId",
"slot_PsfFlux_instFlux": "psfFlux",
"slot_PsfFlux_instFluxErr": "psfFluxErr",
Expand Down
2 changes: 0 additions & 2 deletions python/lsst/ap/association/diaPipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,6 @@ def run(self,
exposure,
diffIm,
idGenerator=idGenerator)
# columns "ra" and "dec" are required for spatial sharding in Cassandra
diaForcedSources.rename(columns={"coord_ra": "ra", "coord_dec": "dec"}, inplace=True)
else:
# alertPackager needs correct columns
diaForcedSources = pd.DataFrame(columns=[
Expand Down
87 changes: 69 additions & 18 deletions python/lsst/ap/association/packageAlerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ class PackageAlertsConfig(pexConfig.Config):
default=1200.0,
)

useAveragePsf = pexConfig.Field(
dtype=bool,
doc="Use the average PSF for the image, instead of the PSF for each cutout. "
"This option is much less accurate, but much faster.",
default=False,
)


class PackageAlertsTask(pipeBase.Task):
"""Tasks for packaging Dia and Pipelines data into Avro alert packages.
Expand Down Expand Up @@ -248,6 +255,9 @@ def run(self,
diffImPhotoCalib = diffIm.getPhotoCalib()
calexpPhotoCalib = calexp.getPhotoCalib()
templatePhotoCalib = template.getPhotoCalib()
diffImPsf = self._computePsf(diffIm, diffIm.psf.getAveragePosition())
sciencePsf = self._computePsf(calexp, calexp.psf.getAveragePosition())
templatePsf = self._computePsf(template, template.psf.getAveragePosition())

n_sources = len(diaSourceCat)
self.log.info("Packaging alerts for %d DiaSources.", n_sources)
Expand Down Expand Up @@ -284,21 +294,24 @@ def run(self,
pixelPoint,
cutoutExtent,
diffImPhotoCalib,
diaSource["diaSourceId"])
diaSource["diaSourceId"],
averagePsf=diffImPsf)
calexpCutout = self.createCcdDataCutout(
calexp,
sphPoint,
pixelPoint,
cutoutExtent,
calexpPhotoCalib,
diaSource["diaSourceId"])
diaSource["diaSourceId"],
averagePsf=sciencePsf)
templateCutout = self.createCcdDataCutout(
template,
sphPoint,
pixelPoint,
cutoutExtent,
templatePhotoCalib,
diaSource["diaSourceId"])
diaSource["diaSourceId"],
averagePsf=templatePsf)

# TODO: Create alertIds DM-24858
alertId = diaSource["diaSourceId"]
Expand Down Expand Up @@ -389,7 +402,7 @@ def produceAlerts(self, alerts, visit, detector):

self.producer.flush()

def createCcdDataCutout(self, image, skyCenter, pixelCenter, extent, photoCalib, srcId):
def createCcdDataCutout(self, image, skyCenter, pixelCenter, extent, photoCalib, srcId, averagePsf=None):
"""Grab an image as a cutout and return a calibrated CCDData image.
Parameters
Expand All @@ -407,6 +420,9 @@ def createCcdDataCutout(self, image, skyCenter, pixelCenter, extent, photoCalib,
srcId : `int`
Unique id of DiaSource. Used for when an error occurs extracting
a cutout.
averagePsf : `numpy.array`, optional
Average PSF to attach to the cutout.
Used if ``self.config.useAveragePsf`` is set.
Returns
-------
Expand All @@ -427,24 +443,18 @@ def createCcdDataCutout(self, image, skyCenter, pixelCenter, extent, photoCalib,
try:
cutout = image.getCutout(pixelCenter, extent)
except InvalidParameterError:
raise InvalidParameterError(
self.log.warning(
"Failed to retrieve cutout from image for DiaSource with "
"id=%i. InvalidParameterError thrown during cutout "
"creation. Returning None for cutout..."
% srcId)
try:
# use image.psf.computeKernelImage to provide PSF centered in the array
cutoutPsf = image.psf.computeKernelImage(pixelCenter).array
except InvalidParameterError:
self.log.warning("Could not calculate PSF for DiaSource with "
"id=%i. InvalidParameterError encountered. Exiting."
% srcId)
cutoutPsf = None
except InvalidPsfError:
self.log.warning("Could not calculate PSF for DiaSource with "
"id=%i. InvalidPsfError encountered. Exiting."
% srcId)
cutoutPsf = None
if self.config.useAveragePsf:
if averagePsf is None:
self.log.info("Using source id=%i to compute the average PSF.", srcId)
averagePsf = self._computePsf(image, pixelCenter, srcId=srcId)
cutoutPsf = averagePsf
else:
cutoutPsf = self._computePsf(image, pixelCenter, srcId=srcId)

# Find the value of the bottom corner of our cutout's BBox and
# subtract 1 so that the CCDData cutout position value will be
Expand Down Expand Up @@ -667,3 +677,44 @@ def _server_check(self):

if not topics:
raise RuntimeError()

def _computePsf(self, exposure, pixelCenter, srcId=None):
"""Compute the PSF at a location and catch errors.
Parameters
----------
exposure : `lsst.afw.image.Exposure`
The image to compute the PSF for.
pixelCenter : `lsst.geom.Point2D`
The location on the image to compute the PSF.
srcId : `int`, optional
Unique id of DiaSource. Used for when an error occurs extracting
a cutout.
Returns
-------
cutoutPsf : `numpy.array`
Array of the PSF values.
"""
try:
# use exposure.psf.computeKernelImage to provide PSF centered in the array
cutoutPsf = exposure.psf.computeKernelImage(pixelCenter).array
except InvalidParameterError:
if srcId is not None:
msg = "Could not calculate PSF for DiaSource with "\
"id=%i. InvalidParameterError encountered. Exiting."\
% srcId
else:
msg = "Could not calculate average PSF for the image"
self.log.warning(msg)
cutoutPsf = None
except InvalidPsfError:
if srcId is not None:
msg = "Could not calculate PSF for DiaSource with "\
"id=%i. InvalidPsfError encountered. Exiting."\
% srcId
else:
msg = "Could not calculate average PSF for the image"
self.log.warning(msg)
cutoutPsf = None
return cutoutPsf
75 changes: 70 additions & 5 deletions tests/test_packageAlerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,21 @@ def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
apdb = Apdb.from_config(apdbConfig)

wholeSky = Box.full()
diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
diaSources = pd.concat(
[apdb.getDiaSources(wholeSky, [], dateTime), sources])
diaForcedSources = pd.concat(
[apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])
loadedObjects = apdb.getDiaObjects(wholeSky)
if loadedObjects.empty:
diaObjects = objects
else:
diaObjects = pd.concat([loadedObjects, objects])
loadedDiaSources = apdb.getDiaSources(wholeSky, [], dateTime)
if loadedDiaSources.empty:
diaSources = sources
else:
diaSources = pd.concat([loadedDiaSources, sources])
loadedDiaForcedSources = apdb.getDiaForcedSources(wholeSky, [], dateTime)
if loadedDiaForcedSources.empty:
diaForcedSources = forcedSources
else:
diaForcedSources = pd.concat([loadedDiaForcedSources, forcedSources])

apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

Expand Down Expand Up @@ -529,6 +539,61 @@ def testRun_without_produce(self, mock_server_check):
self.assertEqual(alert["cutoutDifference"],
packageAlerts.streamCcdDataToBytes(ccdCutout))

@patch.object(PackageAlertsTask, '_server_check')
def testRun_without_produce_use_averagePsf(self, mock_server_check):
"""Test the run method of package alerts with produce set to False and
doWriteAlerts set to true.
"""
packConfig = PackageAlertsConfig(doWriteAlerts=True)
with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
packConfig.alertWriteLocation = tempdir
packConfig.useAveragePsf = True
packageAlerts = PackageAlertsTask(config=packConfig)

packageAlerts.run(self.diaSources,
self.diaObjects,
self.diaSourceHistory,
self.diaForcedSources,
self.exposure,
self.exposure,
self.exposure)

self.assertEqual(mock_server_check.call_count, 0)

with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
writer_schema, data_stream = \
packageAlerts.alertSchema.retrieve_alerts(f)
data = list(data_stream)

self.assertEqual(len(data), len(self.diaSources))
for idx, alert in enumerate(data):
for key, value in alert["diaSource"].items():
if isinstance(value, float):
if np.isnan(self.diaSources.iloc[idx][key]):
self.assertTrue(np.isnan(value))
else:
self.assertAlmostEqual(
1 - value / self.diaSources.iloc[idx][key],
0.)
else:
self.assertEqual(value, self.diaSources.iloc[idx][key])
sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
alert["diaSource"]["dec"],
geom.degrees)
pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"])
cutout = self.exposure.getCutout(sphPoint,
geom.Extent2I(self.cutoutSize,
self.cutoutSize))
ccdCutout = packageAlerts.createCcdDataCutout(
cutout,
sphPoint,
pixelPoint,
geom.Extent2I(self.cutoutSize, self.cutoutSize),
cutout.getPhotoCalib(),
1234)
self.assertEqual(alert["cutoutDifference"],
packageAlerts.streamCcdDataToBytes(ccdCutout))

@patch.object(PackageAlertsTask, 'produceAlerts')
@patch('confluent_kafka.Producer')
@patch.object(PackageAlertsTask, '_server_check')
Expand Down

0 comments on commit c3fbf1a

Please sign in to comment.