Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-46050: Add option to use image average PSF for alert cutouts. #233

Merged
merged 5 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/lsst/ap/association/diaForcedSource.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def _calibrate_and_merge(self,
"slot_PsfFlux")

output_catalog = diff_sources.asAstropy().to_pandas()
# afwTable source catalogs store coordinates as radians, but the
# output must be in degrees
output_catalog.loc[:, "ra"] = np.rad2deg(output_catalog.loc[:, "ra"])
output_catalog.loc[:, "dec"] = np.rad2deg(output_catalog.loc[:, "dec"])
output_catalog.rename(columns={"id": "diaForcedSourceId",
"slot_PsfFlux_instFlux": "psfFlux",
"slot_PsfFlux_instFluxErr": "psfFluxErr",
Expand Down
2 changes: 0 additions & 2 deletions python/lsst/ap/association/diaPipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,6 @@ def run(self,
exposure,
diffIm,
idGenerator=idGenerator)
# columns "ra" and "dec" are required for spatial sharding in Cassandra
diaForcedSources.rename(columns={"coord_ra": "ra", "coord_dec": "dec"}, inplace=True)
else:
# alertPackager needs correct columns
diaForcedSources = pd.DataFrame(columns=[
Expand Down
87 changes: 69 additions & 18 deletions python/lsst/ap/association/packageAlerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ class PackageAlertsConfig(pexConfig.Config):
default=1200.0,
)

useAveragePsf = pexConfig.Field(
dtype=bool,
doc="Use the average PSF for the image, instead of the PSF for each cutout. "
"This option is much less accurate, but much faster.",
default=False,
)


class PackageAlertsTask(pipeBase.Task):
"""Tasks for packaging Dia and Pipelines data into Avro alert packages.
Expand Down Expand Up @@ -248,6 +255,9 @@ def run(self,
diffImPhotoCalib = diffIm.getPhotoCalib()
calexpPhotoCalib = calexp.getPhotoCalib()
templatePhotoCalib = template.getPhotoCalib()
diffImPsf = self._computePsf(diffIm, diffIm.psf.getAveragePosition())
sciencePsf = self._computePsf(calexp, calexp.psf.getAveragePosition())
templatePsf = self._computePsf(template, template.psf.getAveragePosition())

n_sources = len(diaSourceCat)
self.log.info("Packaging alerts for %d DiaSources.", n_sources)
Expand Down Expand Up @@ -284,21 +294,24 @@ def run(self,
pixelPoint,
cutoutExtent,
diffImPhotoCalib,
diaSource["diaSourceId"])
diaSource["diaSourceId"],
averagePsf=diffImPsf)
calexpCutout = self.createCcdDataCutout(
calexp,
sphPoint,
pixelPoint,
cutoutExtent,
calexpPhotoCalib,
diaSource["diaSourceId"])
diaSource["diaSourceId"],
averagePsf=sciencePsf)
templateCutout = self.createCcdDataCutout(
template,
sphPoint,
pixelPoint,
cutoutExtent,
templatePhotoCalib,
diaSource["diaSourceId"])
diaSource["diaSourceId"],
averagePsf=templatePsf)

# TODO: Create alertIds DM-24858
alertId = diaSource["diaSourceId"]
Expand Down Expand Up @@ -389,7 +402,7 @@ def produceAlerts(self, alerts, visit, detector):

self.producer.flush()

def createCcdDataCutout(self, image, skyCenter, pixelCenter, extent, photoCalib, srcId):
def createCcdDataCutout(self, image, skyCenter, pixelCenter, extent, photoCalib, srcId, averagePsf=None):
"""Grab an image as a cutout and return a calibrated CCDData image.

Parameters
Expand All @@ -407,6 +420,9 @@ def createCcdDataCutout(self, image, skyCenter, pixelCenter, extent, photoCalib,
srcId : `int`
Unique id of DiaSource. Used for when an error occurs extracting
a cutout.
averagePsf : `numpy.array`, optional
Average PSF to attach to the cutout.
Used if ``self.config.useAveragePsf`` is set.

Returns
-------
Expand All @@ -427,24 +443,18 @@ def createCcdDataCutout(self, image, skyCenter, pixelCenter, extent, photoCalib,
try:
cutout = image.getCutout(pixelCenter, extent)
except InvalidParameterError:
raise InvalidParameterError(
self.log.warning(
"Failed to retrieve cutout from image for DiaSource with "
"id=%i. InvalidParameterError thrown during cutout "
"creation. Returning None for cutout..."
% srcId)
try:
# use image.psf.computeKernelImage to provide PSF centered in the array
cutoutPsf = image.psf.computeKernelImage(pixelCenter).array
except InvalidParameterError:
self.log.warning("Could not calculate PSF for DiaSource with "
"id=%i. InvalidParameterError encountered. Exiting."
% srcId)
cutoutPsf = None
except InvalidPsfError:
self.log.warning("Could not calculate PSF for DiaSource with "
"id=%i. InvalidPsfError encountered. Exiting."
% srcId)
cutoutPsf = None
if self.config.useAveragePsf:
if averagePsf is None:
self.log.info("Using source id=%i to compute the average PSF.", srcId)
averagePsf = self._computePsf(image, pixelCenter, srcId=srcId)
cutoutPsf = averagePsf
else:
cutoutPsf = self._computePsf(image, pixelCenter, srcId=srcId)

# Find the value of the bottom corner of our cutout's BBox and
# subtract 1 so that the CCDData cutout position value will be
Expand Down Expand Up @@ -667,3 +677,44 @@ def _server_check(self):

if not topics:
raise RuntimeError()

def _computePsf(self, exposure, pixelCenter, srcId=None):
"""Compute the PSF at a location and catch errors.

Parameters
----------
exposure : `lsst.afw.image.Exposure`
The image to compute the PSF for.
pixelCenter : `lsst.geom.Point2D`
The location on the image to compute the PSF.
srcId : `int`, optional
Unique id of DiaSource. Used for when an error occurs extracting
a cutout.

Returns
-------
cutoutPsf : `numpy.array`
Array of the PSF values.
"""
try:
# use exposure.psf.computeKernelImage to provide PSF centered in the array
cutoutPsf = exposure.psf.computeKernelImage(pixelCenter).array
except InvalidParameterError:
if srcId is not None:
msg = "Could not calculate PSF for DiaSource with "\
"id=%i. InvalidParameterError encountered. Exiting."\
% srcId
else:
msg = "Could not calculate average PSF for the image"
self.log.warning(msg)
cutoutPsf = None
except InvalidPsfError:
if srcId is not None:
msg = "Could not calculate PSF for DiaSource with "\
"id=%i. InvalidPsfError encountered. Exiting."\
% srcId
else:
msg = "Could not calculate average PSF for the image"
self.log.warning(msg)
cutoutPsf = None
return cutoutPsf
75 changes: 70 additions & 5 deletions tests/test_packageAlerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,21 @@ def _roundTripThroughApdb(objects, sources, forcedSources, dateTime):
apdb = Apdb.from_config(apdbConfig)

wholeSky = Box.full()
diaObjects = pd.concat([apdb.getDiaObjects(wholeSky), objects])
diaSources = pd.concat(
[apdb.getDiaSources(wholeSky, [], dateTime), sources])
diaForcedSources = pd.concat(
[apdb.getDiaForcedSources(wholeSky, [], dateTime), forcedSources])
loadedObjects = apdb.getDiaObjects(wholeSky)
if loadedObjects.empty:
diaObjects = objects
else:
diaObjects = pd.concat([loadedObjects, objects])
loadedDiaSources = apdb.getDiaSources(wholeSky, [], dateTime)
if loadedDiaSources.empty:
diaSources = sources
else:
diaSources = pd.concat([loadedDiaSources, sources])
loadedDiaForcedSources = apdb.getDiaForcedSources(wholeSky, [], dateTime)
if loadedDiaForcedSources.empty:
diaForcedSources = forcedSources
else:
diaForcedSources = pd.concat([loadedDiaForcedSources, forcedSources])

apdb.store(dateTime, diaObjects, diaSources, diaForcedSources)

Expand Down Expand Up @@ -529,6 +539,61 @@ def testRun_without_produce(self, mock_server_check):
self.assertEqual(alert["cutoutDifference"],
packageAlerts.streamCcdDataToBytes(ccdCutout))

@patch.object(PackageAlertsTask, '_server_check')
def testRun_without_produce_use_averagePsf(self, mock_server_check):
"""Test the run method of package alerts with produce set to False and
doWriteAlerts set to true.
"""
packConfig = PackageAlertsConfig(doWriteAlerts=True)
with tempfile.TemporaryDirectory(prefix='alerts') as tempdir:
packConfig.alertWriteLocation = tempdir
packConfig.useAveragePsf = True
packageAlerts = PackageAlertsTask(config=packConfig)

packageAlerts.run(self.diaSources,
self.diaObjects,
self.diaSourceHistory,
self.diaForcedSources,
self.exposure,
self.exposure,
self.exposure)

self.assertEqual(mock_server_check.call_count, 0)

with open(os.path.join(tempdir, f"{VISIT}_{DETECTOR}.avro"), 'rb') as f:
writer_schema, data_stream = \
packageAlerts.alertSchema.retrieve_alerts(f)
data = list(data_stream)

self.assertEqual(len(data), len(self.diaSources))
for idx, alert in enumerate(data):
for key, value in alert["diaSource"].items():
if isinstance(value, float):
if np.isnan(self.diaSources.iloc[idx][key]):
self.assertTrue(np.isnan(value))
else:
self.assertAlmostEqual(
1 - value / self.diaSources.iloc[idx][key],
0.)
else:
self.assertEqual(value, self.diaSources.iloc[idx][key])
sphPoint = geom.SpherePoint(alert["diaSource"]["ra"],
alert["diaSource"]["dec"],
geom.degrees)
pixelPoint = geom.Point2D(alert["diaSource"]["x"], alert["diaSource"]["y"])
cutout = self.exposure.getCutout(sphPoint,
geom.Extent2I(self.cutoutSize,
self.cutoutSize))
ccdCutout = packageAlerts.createCcdDataCutout(
cutout,
sphPoint,
pixelPoint,
geom.Extent2I(self.cutoutSize, self.cutoutSize),
cutout.getPhotoCalib(),
1234)
self.assertEqual(alert["cutoutDifference"],
packageAlerts.streamCcdDataToBytes(ccdCutout))

@patch.object(PackageAlertsTask, 'produceAlerts')
@patch('confluent_kafka.Producer')
@patch.object(PackageAlertsTask, '_server_check')
Expand Down
Loading