Skip to content

Commit d4be0fe

Browse files
committed
DM-51792 changes
1 parent 60c8b41 commit d4be0fe

File tree

1 file changed

+128
-30
lines changed

1 file changed

+128
-30
lines changed

python/lsst/pipe/tasks/measurementDriver.py

Lines changed: 128 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import lsst.pipe.base as pipeBase
4747
import lsst.scarlet.lite as scl
4848
import numpy as np
49+
from lsst.meas.extensions.scarlet.deconvolveExposureTask import DeconvolveExposureTask
4950
from lsst.pex.config import Config, ConfigurableField, Field
5051

5152
logging.basicConfig(level=logging.INFO)
@@ -126,6 +127,12 @@ def validate(self):
126127
if not any(getattr(self, opt) for opt in self.doOptions):
127128
raise ValueError(f"At least one of these options must be enabled: {self.doOptions}")
128129

130+
if self.doApCorr and not self.doMeasure:
131+
raise ValueError("Aperture correction requires measurement to be enabled.")
132+
133+
if self.doRunCatalogCalculation and not self.doMeasure:
134+
raise ValueError("Catalog calculation requires measurement to be enabled.")
135+
129136

130137
class MeasurementDriverBaseTask(pipeBase.Task, metaclass=ABCMeta):
131138
"""Base class for the mid-level driver running variance scaling, detection,
@@ -260,13 +267,16 @@ def _initializeSchema(self, catalog: afwTable.SourceCatalog = None):
260267
# Since a catalog is provided, use its Schema as the base.
261268
catalogSchema = catalog.schema
262269

263-
# Ensure that the Schema has coordinate error fields.
264-
self._addCoordErrorFieldsIfMissing(catalogSchema)
265-
266270
# Create a SchemaMapper that maps from catalogSchema to a new one
267271
# it will create.
268272
self.mapper = afwTable.SchemaMapper(catalogSchema)
269273

274+
# Ensure coordinate error fields exist in the schema.
275+
# This must be done after initializing the SchemaMapper to avoid
276+
# unequal schemas between input record and mapper during the
277+
# _updateCatalogSchema call.
278+
self._addCoordErrorFieldsIfMissing(catalogSchema)
279+
270280
# Add everything from catalogSchema to output Schema.
271281
self.mapper.addMinimalSchema(catalogSchema, True)
272282

@@ -978,6 +988,7 @@ def __init__(self, *args, **kwargs):
978988
def run(
979989
self,
980990
mExposure: afwImage.MultibandExposure | list[afwImage.Exposure] | afwImage.Exposure,
991+
mDeconvolved: afwImage.MultibandExposure | list[afwImage.Exposure] | afwImage.Exposure | None = None,
981992
refBand: str | None = None,
982993
bands: list[str] | None = None,
983994
catalog: afwTable.SourceCatalog = None,
@@ -989,11 +1000,17 @@ def run(
9891000
Parameters
9901001
----------
9911002
mExposure :
992-
Multi-band data. May be a `MultibandExposure`, a single-band
993-
exposure (i.e., `Exposure`), or a list of single-band exposures
994-
associated with different bands in which case ``bands`` must be
995-
provided. If a single-band exposure is given, it will be treated as
996-
a `MultibandExposure` that contains only that one band.
1003+
Multi-band data containing images of the same shape and region of
1004+
the sky. May be a `MultibandExposure`, a single-band exposure
1005+
(i.e., `Exposure`), or a list of single-band exposures associated
1006+
with different bands in which case ``bands`` must be provided. If a
1007+
single-band exposure is given, it will be treated as a
1008+
`MultibandExposure` that contains only that one band.
1009+
mDeconvolved :
1010+
Multi-band deconvolved images of the same shape and region of the
1011+
sky. Follows the same type conventions as ``mExposure``. If not
1012+
provided, the deblender will run the deconvolution internally
1013+
using the provided ``mExposure``.
9971014
refBand :
9981015
Reference band to use for detection. Not required for single-band
9991016
exposures. If `measureOnlyInRefBand` is enabled while detection is
@@ -1034,7 +1051,9 @@ def run(
10341051
"""
10351052

10361053
# Validate inputs and adjust them as necessary.
1037-
mExposure, refBand, bands = self._ensureValidInputs(mExposure, refBand, bands, catalog)
1054+
mExposure, mDeconvolved, refBand, bands = self._ensureValidInputs(
1055+
mExposure, mDeconvolved, refBand, bands, catalog
1056+
)
10381057

10391058
# Prepare the Schema and subtasks for processing.
10401059
catalog = self._prepareSchemaAndSubtasks(catalog)
@@ -1059,7 +1078,7 @@ def run(
10591078

10601079
# Deblend detected sources and update the catalog(s).
10611080
if self.config.doDeblend:
1062-
catalogs, self.modelData = self._deblendSources(mExposure, catalog, refBand=refBand)
1081+
catalogs, self.modelData = self._deblendSources(mExposure, mDeconvolved, catalog, refBand=refBand)
10631082
else:
10641083
self.log.warning(
10651084
"Skipping deblending; proceeding with the provided catalog in the reference band"
@@ -1078,6 +1097,7 @@ def run(
10781097
def _ensureValidInputs(
10791098
self,
10801099
mExposure: afwImage.MultibandExposure | list[afwImage.Exposure] | afwImage.Exposure,
1100+
mDeconvolved: afwImage.MultibandExposure | list[afwImage.Exposure] | afwImage.Exposure | None,
10811101
refBand: str | None,
10821102
bands: list[str] | None,
10831103
catalog: afwTable.SourceCatalog | None = None,
@@ -1089,6 +1109,8 @@ def _ensureValidInputs(
10891109
----------
10901110
mExposure :
10911111
Multi-band data to be processed by the driver task.
1112+
mDeconvolved :
1113+
Multi-band deconvolved data to be processed by the driver task.
10921114
refBand :
10931115
Reference band to use for detection or measurements.
10941116
bands :
@@ -1100,6 +1122,8 @@ def _ensureValidInputs(
11001122
-------
11011123
mExposure :
11021124
Multi-band exposure to be processed by the driver task.
1125+
mDeconvolved :
1126+
Multi-band deconvolved exposure to be processed by the driver task.
11031127
refBand :
11041128
Reference band to use for detection or measurements.
11051129
bands :
@@ -1110,25 +1134,47 @@ def _ensureValidInputs(
11101134
super()._ensureValidInputs(catalog)
11111135

11121136
# Multi-band-specific validation and adjustments.
1137+
1138+
# Validate mExposure.
11131139
if isinstance(mExposure, afwImage.MultibandExposure):
11141140
if bands is not None:
11151141
if any(b not in mExposure.bands for b in bands):
11161142
raise ValueError(
11171143
"Some bands in the 'bands' list are not present in the input multi-band exposure"
11181144
)
11191145
self.log.info(
1120-
f"Using bands {bands} out of the available {mExposure.bands} in the multi-band exposure"
1146+
f"Using bands {bands} out of the available {mExposure.bands} in the multi-band exposures"
11211147
)
11221148
elif isinstance(mExposure, list):
11231149
if bands is None:
11241150
raise ValueError("The 'bands' list must be provided if 'mExposure' is a list")
11251151
if len(bands) != len(mExposure):
11261152
raise ValueError("Number of bands and exposures must match")
1127-
elif isinstance(mExposure, afwImage.Exposure):
1153+
elif not isinstance(mExposure, afwImage.Exposure):
1154+
raise TypeError(f"Unsupported 'mExposure' type: {type(mExposure)}")
1155+
1156+
# Validate mDeconvolved.
1157+
if mDeconvolved:
1158+
if isinstance(mDeconvolved, afwImage.MultibandExposure):
1159+
if bands is not None:
1160+
if any(b not in mDeconvolved.bands for b in bands):
1161+
raise ValueError(
1162+
"Some bands in the 'bands' list are not present in the input "
1163+
"multi-band deconvolved exposure"
1164+
)
1165+
elif isinstance(mDeconvolved, list):
1166+
if bands is None:
1167+
raise ValueError("The 'bands' list must be provided if 'mDeconvolved' is a list")
1168+
if len(bands) != len(mDeconvolved):
1169+
raise ValueError("Number of bands and deconvolved exposures must match")
1170+
elif not isinstance(mDeconvolved, afwImage.Exposure):
1171+
raise TypeError(f"Unsupported 'mDeconvolved' type: {type(mDeconvolved)}")
1172+
1173+
if isinstance(mExposure, afwImage.Exposure) or isinstance(mDeconvolved, afwImage.Exposure):
11281174
if bands is not None and len(bands) != 1:
11291175
raise ValueError(
11301176
"The 'bands' list, if provided, must only contain a single band "
1131-
"if a single-band exposure is given"
1177+
"if one of 'mExposure' or 'mDeconvolved' is a single-band exposure"
11321178
)
11331179
if bands is None and refBand is None:
11341180
refBand = "unknown" # Placeholder for single-band deblending
@@ -1137,13 +1183,19 @@ def _ensureValidInputs(
11371183
bands = [refBand]
11381184
elif bands is not None and refBand is None:
11391185
refBand = bands[0]
1140-
else:
1141-
raise TypeError(f"Unsupported 'mExposure' type: {type(mExposure)}")
11421186

1143-
# Convert mExposure to a MultibandExposure object with the bands
1144-
# provided.
1187+
# Convert or subset the exposures to a MultibandExposure with the
1188+
# bands of interest.
11451189
mExposure = self._buildMultibandExposure(mExposure, bands)
11461190

1191+
if mDeconvolved:
1192+
mDeconvolved = self._buildMultibandExposure(mDeconvolved, bands)
1193+
if mExposure.bands != mDeconvolved.bands:
1194+
raise ValueError(
1195+
"The bands in 'mExposure' and 'mDeconvolved' must match; "
1196+
f"got {mExposure.bands} and {mDeconvolved.bands}"
1197+
)
1198+
11471199
if len(mExposure.bands) == 1:
11481200
# N.B. Scarlet is designed to leverage multi-band information to
11491201
# differentiate overlapping sources based on their spectral and
@@ -1172,22 +1224,30 @@ def _ensureValidInputs(
11721224
raise ValueError("Reference band must be provided for multi-band data")
11731225

11741226
if refBand not in mExposure.bands:
1175-
raise ValueError(f"Requested band '{refBand}' is not present in the multi-band exposure")
1227+
raise ValueError(f"Requested band '{refBand}' is not present in the multi-band exposures")
11761228

11771229
if bands is not None and refBand not in bands:
11781230
raise ValueError(f"Reference band '{refBand}' is not in the list of 'bands' provided: {bands}")
11791231

1180-
return mExposure, refBand, bands
1232+
return mExposure, mDeconvolved, refBand, bands
11811233

11821234
def _deblendSources(
1183-
self, mExposure: afwImage.MultibandExposure, catalog: afwTable.SourceCatalog, refBand: str
1235+
self,
1236+
mExposure: afwImage.MultibandExposure,
1237+
mDeconvolved: afwImage.MultibandExposure | None,
1238+
catalog: afwTable.SourceCatalog,
1239+
refBand: str,
11841240
) -> tuple[dict[str, afwTable.SourceCatalog], scl.io.ScarletModelData]:
11851241
"""Run multi-band deblending given a multi-band exposure and a catalog.
11861242
11871243
Parameters
11881244
----------
11891245
mExposure :
11901246
Multi-band exposure on which to run the deblending algorithm.
1247+
mDeconvolved :
1248+
Multi-band deconvolved exposure to use for deblending. If None,
1249+
the deblender will create it internally using the provided
1250+
``mExposure``.
11911251
catalog :
11921252
Catalog containing sources to be deblended.
11931253
refBand :
@@ -1206,8 +1266,19 @@ def _deblendSources(
12061266
"""
12071267
self.log.info(f"Deblending using '{self._Deblender}' on {len(catalog)} detection footprints")
12081268

1269+
if mDeconvolved is None:
1270+
# Make a deconvolve version of the multi-band exposure.
1271+
deconvolvedCoadds = []
1272+
deconvolveTask = DeconvolveExposureTask()
1273+
for coadd in mExposure:
1274+
deconvolvedCoadd = deconvolveTask.run(coadd, catalog).deconvolved
1275+
deconvolvedCoadds.append(deconvolvedCoadd)
1276+
mDeconvolved = afwImage.MultibandExposure.fromExposures(mExposure.bands, deconvolvedCoadds)
1277+
12091278
# Run the deblender on the multi-band exposure.
1210-
catalog, modelData = self.deblend.run(mExposure, catalog)
1279+
result = self.deblend.run(mExposure, mDeconvolved, catalog)
1280+
catalog = result.deblendedCatalog
1281+
modelData = result.scarletModelData
12111282

12121283
# Determine which bands to process post-deblending.
12131284
bands = [refBand] if self.config.measureOnlyInRefBand else mExposure.bands
@@ -1319,10 +1390,7 @@ def _validate(self):
13191390
"deblending; set doDetect=False and doDeblend=False"
13201391
)
13211392
if not self.doMeasure:
1322-
raise ValueError(
1323-
"ForcedMeasurementDriverTask must perform measurements; "
1324-
"set doMeasure=True"
1325-
)
1393+
raise ValueError("ForcedMeasurementDriverTask must perform measurements; set doMeasure=True")
13261394

13271395

13281396
class ForcedMeasurementDriverTask(SingleBandMeasurementDriverTask):
@@ -1428,7 +1496,7 @@ def runFromAstropy(
14281496
additional measurement columns defined in the configuration.
14291497
"""
14301498
# Validate inputs before proceeding.
1431-
self._ensureValidInputs(table, exposure, id_column_name, ra_column_name, dec_column_name)
1499+
coord_unit = self._ensureValidInputs(table, exposure, id_column_name, ra_column_name, dec_column_name)
14321500

14331501
# Generate catalog IDs consistently across subtasks.
14341502
if idGenerator is None:
@@ -1445,7 +1513,7 @@ def runFromAstropy(
14451513
# This must be done *after* `_prepareSchemaAndSubtasks`, or the schema
14461514
# won't be set up correctly.
14471515
refCat = self._makeMinimalSourceCatalogFromAstropy(
1448-
table, columns=[id_column_name, ra_column_name, dec_column_name]
1516+
table, columns=[id_column_name, ra_column_name, dec_column_name], coord_unit=coord_unit
14491517
)
14501518

14511519
# Check whether coords are within the image.
@@ -1521,19 +1589,40 @@ def _ensureValidInputs(
15211589
Name of the column containing RA coordinates in the table.
15221590
dec_column_name :
15231591
Name of the column containing Dec coordinates in the table.
1592+
1593+
Returns
1594+
-------
1595+
coord_unit : `str`
1596+
Unit of the sky coordinates extracted from the table.
15241597
"""
15251598
if not isinstance(table, astropy.table.Table):
15261599
raise TypeError(f"Expected 'table' to be an astropy Table, got {type(table)}")
15271600

1601+
if table[ra_column_name].unit == table[dec_column_name].unit:
1602+
if table[ra_column_name].unit == astropy.units.deg:
1603+
coord_unit = "degrees"
1604+
elif table[ra_column_name].unit == astropy.units.rad:
1605+
coord_unit = "radians"
1606+
else:
1607+
# Fallback if it's something else.
1608+
coord_unit = str(table[ra_column_name].unit)
1609+
else:
1610+
raise ValueError("RA and Dec columns must have the same unit")
1611+
15281612
if not isinstance(exposure, afwImage.Exposure):
15291613
raise TypeError(f"Expected 'exposure' to be an Exposure, got {type(exposure)}")
15301614

15311615
for col in [id_column_name, ra_column_name, dec_column_name]:
15321616
if col not in table.colnames:
15331617
raise ValueError(f"Column '{col}' not found in the input table")
15341618

1619+
return coord_unit
1620+
15351621
def _makeMinimalSourceCatalogFromAstropy(
1536-
self, table: astropy.table.Table, columns: list[str] = ["id", "ra", "dec"]
1622+
self,
1623+
table: astropy.table.Table,
1624+
columns: list[str] = ["id", "ra", "dec"],
1625+
coord_unit: str = "degrees",
15371626
):
15381627
"""Convert an Astropy Table to a minimal LSST SourceCatalog.
15391628
@@ -1547,7 +1636,10 @@ def _makeMinimalSourceCatalogFromAstropy(
15471636
Astropy Table containing source IDs and sky coordinates.
15481637
columns :
15491638
Names of the columns in the order [id, ra, dec], where `ra` and
1550-
`dec` are in degrees.
1639+
`dec` are in degrees by default. If the coordinates are in radians,
1640+
set `coord_unit` to "radians".
1641+
coord_unit : `str`
1642+
Unit of the sky coordinates. Can be either "degrees" or "radians".
15511643
15521644
Returns
15531645
-------
@@ -1557,6 +1649,7 @@ def _makeMinimalSourceCatalogFromAstropy(
15571649
Raises
15581650
------
15591651
ValueError
1652+
If `coord_unit` is not "degrees" or "radians".
15601653
If `columns` does not contain exactly 3 items.
15611654
KeyError
15621655
If any of the specified columns are missing from the input table.
@@ -1565,6 +1658,9 @@ def _makeMinimalSourceCatalogFromAstropy(
15651658
# the configs, and move this from being a Task method to a free
15661659
# function that takes column names as args.
15671660

1661+
if coord_unit not in ["degrees", "radians"]:
1662+
raise ValueError(f"Invalid coordinate unit '{coord_unit}'; must be 'degrees' or 'radians'")
1663+
15681664
if len(columns) != 3:
15691665
raise ValueError("`columns` must contain exactly three elements for [id, ra, dec]")
15701666

@@ -1580,6 +1676,8 @@ def _makeMinimalSourceCatalogFromAstropy(
15801676
for row in table:
15811677
outputRecord = outputCatalog.addNew()
15821678
outputRecord.setId(row[idCol])
1583-
outputRecord.setCoord(lsst.geom.SpherePoint(row[raCol], row[decCol], lsst.geom.degrees))
1679+
outputRecord.setCoord(
1680+
lsst.geom.SpherePoint(row[raCol], row[decCol], getattr(lsst.geom, coord_unit))
1681+
)
15841682

15851683
return outputCatalog

0 commit comments

Comments
 (0)