diff --git a/data/association-flag-map.yaml b/data/association-flag-map.yaml index 3e8176d0..af7edc42 100644 --- a/data/association-flag-map.yaml +++ b/data/association-flag-map.yaml @@ -102,4 +102,4 @@ columns: doc: Fake source template injection in source footprint - name: base_PixelFlags_flag_injected_templateCenter bit: 33 - doc: Fake source template injection center in source footprint + doc: Fake source template injection center in source footprint \ No newline at end of file diff --git a/python/lsst/ap/association/__init__.py b/python/lsst/ap/association/__init__.py index 16cd34da..fd84f9a7 100644 --- a/python/lsst/ap/association/__init__.py +++ b/python/lsst/ap/association/__init__.py @@ -25,6 +25,5 @@ from .loadDiaCatalogs import * from .packageAlerts import * from .diaPipe import * -from .trailedSourceFilter import * from .transformDiaSourceCatalog import * from .version import * \ No newline at end of file diff --git a/python/lsst/ap/association/association.py b/python/lsst/ap/association/association.py index d2bd043a..eb4e0e41 100644 --- a/python/lsst/ap/association/association.py +++ b/python/lsst/ap/association/association.py @@ -32,7 +32,6 @@ import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase from lsst.utils.timer import timeMethod -from .trailedSourceFilter import TrailedSourceFilterTask # Enforce an error for unsafe column/array value setting in pandas. pd.options.mode.chained_assignment = 'raise' @@ -49,19 +48,6 @@ class AssociationConfig(pexConfig.Config): default=1.0, ) - trailedSourceFilter = pexConfig.ConfigurableField( - target=TrailedSourceFilterTask, - doc="Subtask to remove long trailed sources based on catalog source " - "morphological measurements.", - ) - - doTrailedSourceFilter = pexConfig.Field( - doc="Run traildeSourceFilter to remove long trailed sources from " - "output catalog.", - dtype=bool, - default=True, - ) - class AssociationTask(pipeBase.Task): """Associate DIAOSources into existing DIAObjects. @@ -75,16 +61,10 @@ class AssociationTask(pipeBase.Task): ConfigClass = AssociationConfig _DefaultName = "association" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.config.doTrailedSourceFilter: - self.makeSubtask("trailedSourceFilter") - @timeMethod def run(self, diaSources, - diaObjects, - exposure_time=None): + diaObjects): """Associate the new DiaSources with existing DiaObjects. Parameters @@ -93,8 +73,6 @@ def run(self, New DIASources to be associated with existing DIAObjects. diaObjects : `pandas.DataFrame` Existing diaObjects from the Apdb. - exposure_time : `float`, optional - Exposure time from difference image. Returns ------- @@ -112,30 +90,15 @@ def run(self, matched to new DiaSources. (`int`) - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were not matched a new DiaSource. (`int`) - - ``longTrailedSources`` : DiaSources which have trail lengths - greater than max_trail_length/second*exposure_time. - (`pandas.DataFrame``) """ diaSources = self.check_dia_source_radec(diaSources) - if self.config.doTrailedSourceFilter: - diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time) - diaSources = diaTrailedResult.diaSources - longTrailedSources = diaTrailedResult.longTrailedDiaSources - - self.log.info("%i DiaSources exceed max_trail_length, dropping from source " - "catalog." % len(diaTrailedResult.longTrailedDiaSources)) - self.metadata.add("num_filtered", len(diaTrailedResult.longTrailedDiaSources)) - else: - longTrailedSources = pd.DataFrame(columns=diaSources.columns) - if len(diaObjects) == 0: return pipeBase.Struct( matchedDiaSources=pd.DataFrame(columns=diaSources.columns), unAssocDiaSources=diaSources, nUpdatedDiaObjects=0, - nUnassociatedDiaObjects=0, - longTrailedSources=longTrailedSources) + nUnassociatedDiaObjects=0) matchResult = self.associate_sources(diaObjects, diaSources) @@ -145,8 +108,7 @@ def run(self, matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True), unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True), nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects, - nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects, - longTrailedSources=longTrailedSources) + nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects) def check_dia_source_radec(self, dia_sources): """Check that all DiaSources have non-NaN values for RA/DEC. diff --git a/python/lsst/ap/association/diaPipe.py b/python/lsst/ap/association/diaPipe.py index 9ee2d556..3aacd4b3 100644 --- a/python/lsst/ap/association/diaPipe.py +++ b/python/lsst/ap/association/diaPipe.py @@ -35,23 +35,22 @@ import warnings -import numpy as np -import pandas as pd - -from lsst.daf.base import DateTime import lsst.dax.apdb as daxApdb -from lsst.meas.base import DetectorVisitIdGeneratorConfig, DiaObjectCalculationTask import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase import lsst.pipe.base.connectionTypes as connTypes -from lsst.utils.timer import timeMethod - +import numpy as np +import pandas as pd from lsst.ap.association import ( AssociationTask, DiaForcedSourceTask, LoadDiaCatalogsTask, PackageAlertsTask) from lsst.ap.association.ssoAssociation import SolarSystemAssociationTask +from lsst.daf.base import DateTime +from lsst.meas.base import DetectorVisitIdGeneratorConfig, \ + DiaObjectCalculationTask +from lsst.utils.timer import timeMethod class DiaPipelineConnections( @@ -119,12 +118,6 @@ class DiaPipelineConnections( storageClass="DataFrame", dimensions=("instrument", "visit", "detector"), ) - longTrailedSources = pipeBase.connectionTypes.Output( - doc="Optional output temporarily storing long trailed diaSources.", - dimensions=("instrument", "visit", "detector"), - storageClass="DataFrame", - name="{fakesType}{coaddName}Diff_longTrailedSrc", - ) def __init__(self, *, config=None): super().__init__(config=config) @@ -137,8 +130,6 @@ def __init__(self, *, config=None): self.outputs.remove("diaForcedSources") if not config.doSolarSystemAssociation: self.inputs.remove("solarSystemObjectTable") - if not config.associator.doTrailedSourceFilter: - self.outputs.remove("longTrailedSources") def adjustQuantum(self, inputs, outputs, label, dataId): """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects @@ -438,10 +429,8 @@ def run(self, buffer=self.config.imagePixelMargin) else: diaObjects = loaderResult.diaObjects - # Associate new DiaSources with existing DiaObjects. - assocResults = self.associator.run(diaSourceTable, diaObjects, - exposure_time=diffIm.visitInfo.exposureTime) + assocResults = self.associator.run(diaSourceTable, diaObjects) if self.config.doSolarSystemAssociation: ssoAssocResult = self.solarSystemAssociator.run( @@ -627,7 +616,6 @@ def run(self, associatedDiaSources=associatedDiaSources, diaForcedSources=diaForcedSources, diaObjects=diaObjects, - longTrailedSources=assocResults.longTrailedSources ) def createNewDiaObjects(self, unAssocDiaSources): diff --git a/python/lsst/ap/association/filterDiaSourceCatalog.py b/python/lsst/ap/association/filterDiaSourceCatalog.py index 9330e1e9..b6b985f7 100644 --- a/python/lsst/ap/association/filterDiaSourceCatalog.py +++ b/python/lsst/ap/association/filterDiaSourceCatalog.py @@ -61,10 +61,28 @@ class FilterDiaSourceCatalogConnections( dimensions={"instrument", "visit", "detector"}, ) + diffImVisitInfo = connTypes.Input( + doc="VisitInfo of diffIm.", + name="{fakesType}{coaddName}Diff_differenceExp.visitInfo", + storageClass="VisitInfo", + dimensions=("instrument", "visit", "detector"), + ) + + longTrailedSources = connTypes.Output( + doc="Optional output temporarily storing long trailed diaSources.", + dimensions=("instrument", "visit", "detector"), + storageClass="ArrowAstropy", + name="{fakesType}{coaddName}Diff_longTrailedSrc", + ) + def __init__(self, *, config=None): super().__init__(config=config) - if not self.config.doWriteRejectedSources: + if not self.config.doWriteRejectedSkySources: self.outputs.remove("rejectedDiaSources") + if not self.config.doTrailedSourceFilter: + self.outputs.remove("longTrailedSources") + if not self.config.doWriteTrailedSources: + self.outputs.remove("longTrailedSources") class FilterDiaSourceCatalogConfig( @@ -79,13 +97,37 @@ class FilterDiaSourceCatalogConfig( "removed before storing the output DiaSource catalog.", ) - doWriteRejectedSources = pexConfig.Field( + doWriteRejectedSkySources = pexConfig.Field( dtype=bool, default=True, doc="Store the output DiaSource catalog containing all the rejected " "sky sources." ) + doTrailedSourceFilter = pexConfig.Field( + doc="Run trailedSourceFilter to remove long trailed sources from the" + "diaSource output catalog.", + dtype=bool, + default=True, + ) + + doWriteTrailedSources = pexConfig.Field( + doc="Write trailed diaSources sources to a table.", + dtype=bool, + default=True, + deprecated="Trailed sources will not be written out during production." + ) + + max_trail_length = pexConfig.Field( + dtype=float, + doc="Length of long trailed sources to remove from the input catalog, " + "in arcseconds per second. Default comes from DMTN-199, which " + "requires removal of sources with trails longer than 10 " + "degrees/day, which is 36000/3600/24 arcsec/second, or roughly" + "0.416 arcseconds per second.", + default=36000/3600.0/24.0, + ) + class FilterDiaSourceCatalogTask(pipeBase.PipelineTask): """Filter out sky sources from a DiaSource catalog.""" @@ -94,13 +136,15 @@ class FilterDiaSourceCatalogTask(pipeBase.PipelineTask): _DefaultName = "filterDiaSourceCatalog" @timeMethod - def run(self, diaSourceCat): + def run(self, diaSourceCat, diffImVisitInfo): """Filter sky sources from the supplied DiaSource catalog. Parameters ---------- diaSourceCat : `lsst.afw.table.SourceCatalog` Catalog of sources measured on the difference image. + diffImVisitInfo: `lsst.afw.image.VisitInfo` + VisitInfo for the difference image corresponding to diaSourceCat. Returns ------- @@ -109,9 +153,13 @@ def run(self, diaSourceCat): ``filteredDiaSourceCat`` : `lsst.afw.table.SourceCatalog` The catalog of filtered sources. ``rejectedDiaSources`` : `lsst.afw.table.SourceCatalog` - The catalog of rejected sources. + The catalog of rejected sky sources. + ``longTrailedDiaSources`` : `astropy.table.Table` + DiaSources which have trail lengths greater than + max_trail_length*exposure_time. """ rejectedSkySources = None + exposure_time = diffImVisitInfo.exposureTime if self.config.doRemoveSkySources: sky_source_column = diaSourceCat["sky_source"] num_sky_sources = np.sum(sky_source_column) @@ -120,6 +168,64 @@ def run(self, diaSourceCat): self.log.info(f"Filtered {num_sky_sources} sky sources.") if not rejectedSkySources: rejectedSkySources = SourceCatalog(diaSourceCat.getSchema()) - filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, - rejectedDiaSources=rejectedSkySources) + + if self.config.doTrailedSourceFilter: + trail_mask = self._check_dia_source_trail(diaSourceCat, exposure_time) + longTrailedDiaSources = diaSourceCat[trail_mask].copy(deep=True) + diaSourceCat = diaSourceCat[~trail_mask] + + self.log.info("%i DiaSources exceed max_trail_length %f arcseconds per second, " + "dropping from source catalog." + % (self.config.max_trail_length, len(diaSourceCat))) + self.metadata.add("num_filtered", len(longTrailedDiaSources)) + + if self.config.doWriteTrailedSources: + filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, + rejectedDiaSources=rejectedSkySources, + longTrailedSources=longTrailedDiaSources.asAstropy()) + else: + filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, + rejectedDiaSources=rejectedSkySources) + + else: + filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, + rejectedDiaSources=rejectedSkySources) + return filterResults + + def _check_dia_source_trail(self, dia_sources, exposure_time): + """Find DiaSources that have long trails or trails with indeterminant + end points. + + Return a mask of sources with lengths greater than + (``config.max_trail_length`` multiplied by the exposure time) + arcseconds. + Additionally, set mask if + ``ext_trailedSources_Naive_flag_off_image`` is set or if + ``ext_trailedSources_Naive_flag_suspect_long_trail`` and + ``ext_trailedSources_Naive_flag_edge`` are both set. + + Parameters + ---------- + dia_sources : `pandas.DataFrame` + Input diaSources to check for trail lengths. + exposure_time : `float` + Exposure time from difference image. + + Returns + ------- + trail_mask : `pandas.DataFrame` + Boolean mask for diaSources which are greater than the + Boolean mask for diaSources which are greater than the + cutoff length or have trails which extend beyond the edge of the + detector (off_image set). Also checks if both + suspect_long_trail and edge are set and masks those sources out. + """ + print(dia_sources.getSchema()) + trail_mask = (dia_sources["ext_trailedSources_Naive_length"] + >= (self.config.max_trail_length*exposure_time)) + trail_mask |= dia_sources['ext_trailedSources_Naive_flag_off_image'] + trail_mask |= (dia_sources['ext_trailedSources_Naive_flag_suspect_long_trail'] + & dia_sources['ext_trailedSources_Naive_flag_edge']) + + return trail_mask diff --git a/python/lsst/ap/association/trailedSourceFilter.py b/python/lsst/ap/association/trailedSourceFilter.py deleted file mode 100644 index 8408ffb2..00000000 --- a/python/lsst/ap/association/trailedSourceFilter.py +++ /dev/null @@ -1,129 +0,0 @@ -# This file is part of ap_association. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -__all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig") - -import os -import numpy as np - -import lsst.pex.config as pexConfig -import lsst.pipe.base as pipeBase -from lsst.utils.timer import timeMethod -from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags -import lsst.utils as utils - - -class TrailedSourceFilterConfig(pexConfig.Config): - """Config class for TrailedSourceFilterTask. - """ - - max_trail_length = pexConfig.Field( - dtype=float, - doc="Length of long trailed sources to remove from the input catalog, " - "in arcseconds per second. Default comes from DMTN-199, which " - "requires removal of sources with trails longer than 10 " - "degrees/day, which is 36000/3600/24 arcsec/second, or roughly" - "0.416 arcseconds per second.", - default=36000/3600.0/24.0, - ) - - -class TrailedSourceFilterTask(pipeBase.Task): - """Find trailed sources in DIASources and filter them as per DMTN-199 - guidelines. - - This task checks the length of trailLength in the DIASource catalog using - a given arcsecond/second rate from max_trail_length and the exposure time. - The two values are used to calculate the maximum allowed trail length and - filters out any trail longer than the maximum. The max_trail_length is - outlined in DMTN-199 and determines the default value. - """ - - ConfigClass = TrailedSourceFilterConfig - _DefaultName = "trailedSourceFilter" - - @timeMethod - def run(self, dia_sources, exposure_time): - """Remove trailed sources longer than ``config.max_trail_length`` from - the input catalog. - - Parameters - ---------- - dia_sources : `pandas.DataFrame` - New DIASources to be checked for trailed sources. - exposure_time : `float` - Exposure time from difference image. - - Returns - ------- - result : `lsst.pipe.base.Struct` - Results struct with components. - - - ``diaSources`` : DIASource table that is free from unwanted - trailed sources. (`pandas.DataFrame`) - - - ``longTrailedDiaSources`` : DIASources that have trails which - exceed max_trail_length/second*exposure_time (seconds). - (`pandas.DataFrame`) - """ - - if "flags" in dia_sources.columns: - flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") - unpacker = UnpackApdbFlags(flag_map, "DiaSource") - flags = unpacker.unpack(dia_sources["flags"], "flags") - trail_edge_flags = flags["ext_trailedSources_Naive_flag_edge"] - else: - trail_edge_flags = dia_sources["trail_flag_edge"] - - trail_mask = self._check_dia_source_trail(dia_sources, exposure_time, trail_edge_flags) - - return pipeBase.Struct( - diaSources=dia_sources[~trail_mask].reset_index(drop=True), - longTrailedDiaSources=dia_sources[trail_mask].reset_index(drop=True)) - - def _check_dia_source_trail(self, dia_sources, exposure_time, trail_edge_flags): - """Find DiaSources that have long trails. - - Return a mask of sources with lengths greater than - ``config.max_trail_length`` multiplied by the exposure time in seconds - or have ext_trailedSources_Naive_flag_edge set. - - Parameters - ---------- - dia_sources : `pandas.DataFrame` - Input DIASources to check for trail lengths. - exposure_time : `float` - Exposure time from difference image. - trail_edge_flags : 'numpy.ndArray' - Boolean array of trail_flag_edge flags from the DIASources. - - Returns - ------- - trail_mask : `pandas.DataFrame` - Boolean mask for DIASources which are greater than the - cutoff length or have the edge flag set. - """ - trail_mask = (dia_sources.loc[:, "trailLength"].values[:] - >= (self.config.max_trail_length*exposure_time)) - - trail_mask[np.where(trail_edge_flags)] = True - - return trail_mask diff --git a/tests/test_association_task.py b/tests/test_association_task.py index 22b3d8fa..62873d73 100644 --- a/tests/test_association_task.py +++ b/tests/test_association_task.py @@ -54,16 +54,14 @@ def setUp(self): "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx, "flags": 0} for idx in range(self.nSources)]) - self.exposure_time = 30.0 def test_run(self): """Test the full task by associating a set of diaSources to existing diaObjects. """ config = AssociationTask.ConfigClass() - config.doTrailedSourceFilter = False assocTask = AssociationTask(config=config) - results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time) + results = assocTask.run(self.diaSources, self.diaObjects) self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1) self.assertEqual(results.nUnassociatedDiaObjects, 1) @@ -73,31 +71,13 @@ def test_run(self): np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2, 3, 4]) np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0]) - def test_run_trailed_sources(self): - """Test the full task by associating a set of diaSources to - existing diaObjects when trailed sources are filtered. - - This should filter out two of the five sources based on trail length, - leaving one unassociated diaSource and two associated diaSources. - """ - assocTask = AssociationTask() - results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time) - - self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3) - self.assertEqual(results.nUnassociatedDiaObjects, 3) - self.assertEqual(len(results.matchedDiaSources), len(self.diaObjects) - 3) - self.assertEqual(len(results.unAssocDiaSources), 1) - np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2]) - np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0]) - def test_run_no_existing_objects(self): """Test the run method with a completely empty database. """ assocTask = AssociationTask() results = assocTask.run( self.diaSources, - pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]), - exposure_time=self.exposure_time) + pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"])) self.assertEqual(results.nUpdatedDiaObjects, 0) self.assertEqual(results.nUnassociatedDiaObjects, 0) self.assertEqual(len(results.matchedDiaSources), 0) diff --git a/tests/test_diaPipe.py b/tests/test_diaPipe.py index 1f3ec49e..edaed345 100644 --- a/tests/test_diaPipe.py +++ b/tests/test_diaPipe.py @@ -183,11 +183,10 @@ def solarSystemAssociator_run(unAssocDiaSources, solarSystemObjectTable, diffIm) ssoAssocDiaSources=_makeMockDataFrame(), unAssocDiaSources=_makeMockDataFrame()) - def associator_run(table, diaObjects, exposure_time=None): + def associator_run(table, diaObjects): return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3, matchedDiaSources=_makeMockDataFrame(), - unAssocDiaSources=_makeMockDataFrame(), - longTrailedSources=None) + unAssocDiaSources=_makeMockDataFrame()) # apdb isn't a subtask, but still needs to be mocked out for correct # execution in the test environment. diff --git a/tests/test_filterDiaSourceCatalog.py b/tests/test_filterDiaSourceCatalog.py index 73cf1b5f..b813a3bd 100644 --- a/tests/test_filterDiaSourceCatalog.py +++ b/tests/test_filterDiaSourceCatalog.py @@ -26,12 +26,14 @@ import lsst.geom as geom import lsst.meas.base.tests as measTests import lsst.utils.tests +import lsst.afw.image as afwImage +import lsst.daf.base as dafBase class TestFilterDiaSourceCatalogTask(unittest.TestCase): def setUp(self): - self.nSources = 10 + self.nSources = 15 self.nSkySources = 5 self.yLoc = 100 self.expId = 4321 @@ -42,24 +44,70 @@ def setUp(self): dataset.addSource(10000.0, geom.Point2D(srcIdx, self.yLoc)) schema = dataset.makeMinimalSchema() schema.addField("sky_source", type="Flag", doc="Sky objects.") + schema.addField('ext_trailedSources_Naive_flag_off_image', type="Flag", + doc="Trail extends off image") + schema.addField('ext_trailedSources_Naive_flag_suspect_long_trail', + type="Flag", doc="Trail length is greater than three times the psf radius") + schema.addField('ext_trailedSources_Naive_flag_edge', type="Flag", + doc="Trail contains edge pixels") + schema.addField('ext_trailedSources_Naive_flag_nan', type="Flag", + doc="One or more trail coordinates are missing") + schema.addField('ext_trailedSources_Naive_length', type="F", + doc="Length of the source trail") _, self.diaSourceCat = dataset.realize(10.0, schema, randomSeed=1234) self.diaSourceCat[0:self.nSkySources]["sky_source"] = True + # The last 10 sources will all contained trail length measurements, + # increasing in size by 1.5 arcseconds. Only the last three will have + # lengths which are too long and will be filtered out. + self.nFilteredTrailedSources = 0 + for srcIdx in range(5, 15): + self.diaSourceCat[srcIdx]["ext_trailedSources_Naive_length"] = 1.5*(srcIdx-4) + if 1.5*(srcIdx-4) > 36000/3600.0/24.0 * 30.0: + self.nFilteredTrailedSources += 1 + # Setting a combination of flags for filtering in tests + self.diaSourceCat[5]["ext_trailedSources_Naive_flag_off_image"] = True + self.diaSourceCat[6]["ext_trailedSources_Naive_flag_suspect_long_trail"] = True + self.diaSourceCat[6]["ext_trailedSources_Naive_flag_edge"] = True + # As only two of these flags are set, the total number of filtered + # sources will be self.nFilteredTrailedSources + 2 + self.nFilteredTrailedSources += 2 self.config = FilterDiaSourceCatalogConfig() + mjd = 57071.0 + self.utc_jd = mjd + 2_400_000.5 - 35.0 / (24.0 * 60.0 * 60.0) + + self.visitInfo = afwImage.VisitInfo( + # This incomplete visitInfo is sufficient for testing because the + # Python constructor sets all other required values to some + # default. + exposureTime=30.0, + darkTime=3.0, + date=dafBase.DateTime(mjd, system=dafBase.DateTime.MJD), + boresightRaDec=geom.SpherePoint(0.0, 0.0, geom.degrees), + ) def test_run_without_filter(self): + """Test that when all filters are turned off all sources in the catalog + are returned. + """ self.config.doRemoveSkySources = False - self.config.doWriteRejectedSources = True + self.config.doWriteRejectedSkySources = False + self.config.doTrailedSourceFilter = False filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) - result = filterDiaSourceCatalogTask.run(self.diaSourceCat) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat)) self.assertEqual(len(result.rejectedDiaSources), 0) self.assertEqual(len(self.diaSourceCat), self.nSources) - def test_run_with_filter(self): + def test_run_with_filter_sky_only(self): + """Test that when only the sky filter is turned on the first five + sources which are flagged as sky objects are filtered out of the + catalog and the rest are returned. + """ self.config.doRemoveSkySources = True - self.config.doWriteRejectedSources = True + self.config.doWriteRejectedSkySources = True + self.config.doTrailedSourceFilter = False filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) - result = filterDiaSourceCatalogTask.run(self.diaSourceCat) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) nExpectedFilteredSources = self.nSources - self.nSkySources self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat[~self.diaSourceCat['sky_source']])) @@ -67,6 +115,39 @@ def test_run_with_filter(self): self.assertEqual(len(result.rejectedDiaSources), self.nSkySources) self.assertEqual(len(self.diaSourceCat), self.nSources) + def test_run_with_filter_trailed_sources_only(self): + """Test that when only the trail filter is turned on the correct number + of sources are filtered out. The filtered sources should be the last + three sources which have long trails, one source where both the suspect + trail and edge trail flag are set, and one source where off_image is + set. All sky objects should remain in the catalog. + """ + self.config.doRemoveSkySources = False + self.config.doWriteRejectedSkySources = False + self.config.doTrailedSourceFilter = True + filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) + nExpectedFilteredSources = self.nSources - self.nFilteredTrailedSources + self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources) + self.assertEqual(len(self.diaSourceCat), self.nSources) + + def test_run_with_all_filters(self): + """Test that all sources are filtered out correctly. Only six sources + should remain in the catalog after filtering. + """ + self.config.doRemoveSkySources = True + self.config.doWriteRejectedSkySources = True + self.config.doTrailedSourceFilter = True + filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) + nExpectedFilteredSources = self.nSources - self.nSkySources - self.nFilteredTrailedSources + # 5 filtered out sky sources + # 4 filtered out trailed sources, 2 with long trails 2 with flags + # 6 sources left + self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources) + self.assertEqual(len(result.rejectedDiaSources), self.nSkySources) + self.assertEqual(len(self.diaSourceCat), self.nSources) + class MemoryTester(lsst.utils.tests.MemoryTestCase): pass diff --git a/tests/test_trailedSourceFilter.py b/tests/test_trailedSourceFilter.py deleted file mode 100644 index e777aaea..00000000 --- a/tests/test_trailedSourceFilter.py +++ /dev/null @@ -1,167 +0,0 @@ -# This file is part of ap_association. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -import unittest -import os -import numpy as np -import pandas as pd - -import lsst.utils.tests -import lsst.utils as utils -from lsst.ap.association import TrailedSourceFilterTask -from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags - - -class TestTrailedSourceFilterTask(unittest.TestCase): - - def setUp(self): - """Create sets of diaSources. - - The trail lengths of the dia sources are 0, 5.5, 11, 16.5, 21.5 - arcseconds. - """ - # Create an instance of random generator with fixed seed. - rng = np.random.default_rng(1234) - - scatter = 0.1 / 3600 - self.nSources = 5 - self.diaSources = pd.DataFrame(data=[ - {"ra": 0.04*idx + scatter*rng.uniform(-1, 1), - "dec": 0.04*idx + scatter*rng.uniform(-1, 1), - "diaSourceId": idx, "diaObjectId": 0, "trailLength": 5.5*idx, - "flags": 0} - for idx in range(self.nSources)]) - self.exposure_time = 30.0 - - # For use only with testing the edge flag - self.edgeDiaSources = pd.DataFrame(data=[ - {"ra": 0.04*idx + scatter*rng.uniform(-1, 1), - "dec": 0.04*idx + scatter*rng.uniform(-1, 1), - "diaSourceId": idx, "diaObjectId": 0, "trailLength": 0, - "flags": 0} - for idx in range(self.nSources)]) - - flagMap = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") - unpacker = UnpackApdbFlags(flagMap, "DiaSource") - bitMask = unpacker.makeFlagBitMask(["ext_trailedSources_Naive_flag_edge"]) - # Flag two sources as "trailed on the edge". - self.edgeDiaSources.loc[[1, 4], "flags"] |= bitMask - - def test_run(self): - """Run trailedSourceFilterTask with the default max distance. - - With the default settings and an exposure of 30 seconds, the max trail - length is 12.5 arcseconds. Two out of five of the diaSources will be - filtered out of the final results and put into results.trailedSources. - """ - trailedSourceFilterTask = TrailedSourceFilterTask() - - results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 3) - np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 1, 2]) - np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [3, 4]) - - def test_run_short_max_trail(self): - """Run trailedSourceFilterTask with aggressive trail length cutoff - - With a max_trail_length config of 0.01 arcseconds/second and an - exposure of 30 seconds,the max trail length is 0.3 arcseconds. Only the - source with a trail of 0 stays in the catalog and the rest are filtered - out and put into results.trailedSources. - """ - config = TrailedSourceFilterTask.ConfigClass() - config.max_trail_length = 0.01 - trailedSourceFilterTask = TrailedSourceFilterTask(config=config) - results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 1) - np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0]) - np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 2, 3, 4]) - - def test_run_no_trails(self): - """Run trailedSourceFilterTask with a long trail length so that - every source in the catalog is in the final diaSource catalog. - - With a max_trail_length config of 10 arcseconds/second and an - exposure of 30 seconds,the max trail length is 300 arcseconds. All - sources in the initial catalog should be in the final diaSource - catalog. - """ - config = TrailedSourceFilterTask.ConfigClass() - config.max_trail_length = 10.00 - trailedSourceFilterTask = TrailedSourceFilterTask(config=config) - results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 5) - self.assertEqual(len(results.longTrailedDiaSources), 0) - np.testing.assert_array_equal(results.diaSources["diaSourceId"].values, [0, 1, 2, 3, 4]) - np.testing.assert_array_equal(results.longTrailedDiaSources["diaSourceId"].values, []) - - def test_run_edge(self): - """Run trailedSourceFilterTask on a source on the edge. - """ - trailedSourceFilterTask = TrailedSourceFilterTask() - - results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 3) - np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 2, 3]) - np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 4]) - - def test_check_dia_source_trail(self): - """Test that the DiaSource trail checker is correctly identifying - long trails - - Test that the trail source mask filter returns the expected mask array. - """ - trailedSourceFilterTask = TrailedSourceFilterTask() - flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") - unpacker = UnpackApdbFlags(flag_map, "DiaSource") - flags = unpacker.unpack(self.diaSources["flags"], "flags") - trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, - self.exposure_time, flags) - - np.testing.assert_array_equal(trailed_source_mask, [False, False, False, True, True]) - - flags = unpacker.unpack(self.edgeDiaSources["flags"], "flags") - trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.edgeDiaSources, - self.exposure_time, flags) - np.testing.assert_array_equal(trailed_source_mask, [False, True, False, False, True]) - - # Mixing the flags from edgeDiaSources and diaSources means the mask - # will be set using both criteria. - trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, - self.exposure_time, flags) - np.testing.assert_array_equal(trailed_source_mask, [False, True, False, True, True]) - - -class MemoryTester(lsst.utils.tests.MemoryTestCase): - pass - - -def setup_module(module): - lsst.utils.tests.init() - - -if __name__ == "__main__": - lsst.utils.tests.init() - unittest.main()