From c5926b6244c73170f15a16324c56ca8ce2305eed Mon Sep 17 00:00:00 2001 From: Brianna Smart Date: Tue, 7 May 2024 16:39:59 -0700 Subject: [PATCH] Move trailed source filtering to filterDiaSourceCatalog Move the trailed source filtering into filterDiaSourceCatalog. Additionally, update filtering so only trails that extend of the edge or which have long trails and are near the edge are filtered out, along with the sources which exceed the max trail length. --- data/association-flag-map.yaml | 2 +- python/lsst/ap/association/__init__.py | 1 - python/lsst/ap/association/association.py | 44 +---- python/lsst/ap/association/diaPipe.py | 26 +-- .../ap/association/filterDiaSourceCatalog.py | 118 ++++++++++++- .../ap/association/trailedSourceFilter.py | 129 -------------- tests/test_association_task.py | 24 +-- tests/test_diaPipe.py | 5 +- tests/test_filterDiaSourceCatalog.py | 93 +++++++++- tests/test_trailedSourceFilter.py | 167 ------------------ 10 files changed, 214 insertions(+), 395 deletions(-) delete mode 100644 python/lsst/ap/association/trailedSourceFilter.py delete mode 100644 tests/test_trailedSourceFilter.py diff --git a/data/association-flag-map.yaml b/data/association-flag-map.yaml index 3e8176d0..af7edc42 100644 --- a/data/association-flag-map.yaml +++ b/data/association-flag-map.yaml @@ -102,4 +102,4 @@ columns: doc: Fake source template injection in source footprint - name: base_PixelFlags_flag_injected_templateCenter bit: 33 - doc: Fake source template injection center in source footprint + doc: Fake source template injection center in source footprint \ No newline at end of file diff --git a/python/lsst/ap/association/__init__.py b/python/lsst/ap/association/__init__.py index 16cd34da..fd84f9a7 100644 --- a/python/lsst/ap/association/__init__.py +++ b/python/lsst/ap/association/__init__.py @@ -25,6 +25,5 @@ from .loadDiaCatalogs import * from .packageAlerts import * from .diaPipe import * -from .trailedSourceFilter import * from .transformDiaSourceCatalog import * from .version import * \ No newline at end of file diff --git a/python/lsst/ap/association/association.py b/python/lsst/ap/association/association.py index d2bd043a..eb4e0e41 100644 --- a/python/lsst/ap/association/association.py +++ b/python/lsst/ap/association/association.py @@ -32,7 +32,6 @@ import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase from lsst.utils.timer import timeMethod -from .trailedSourceFilter import TrailedSourceFilterTask # Enforce an error for unsafe column/array value setting in pandas. pd.options.mode.chained_assignment = 'raise' @@ -49,19 +48,6 @@ class AssociationConfig(pexConfig.Config): default=1.0, ) - trailedSourceFilter = pexConfig.ConfigurableField( - target=TrailedSourceFilterTask, - doc="Subtask to remove long trailed sources based on catalog source " - "morphological measurements.", - ) - - doTrailedSourceFilter = pexConfig.Field( - doc="Run traildeSourceFilter to remove long trailed sources from " - "output catalog.", - dtype=bool, - default=True, - ) - class AssociationTask(pipeBase.Task): """Associate DIAOSources into existing DIAObjects. @@ -75,16 +61,10 @@ class AssociationTask(pipeBase.Task): ConfigClass = AssociationConfig _DefaultName = "association" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.config.doTrailedSourceFilter: - self.makeSubtask("trailedSourceFilter") - @timeMethod def run(self, diaSources, - diaObjects, - exposure_time=None): + diaObjects): """Associate the new DiaSources with existing DiaObjects. Parameters @@ -93,8 +73,6 @@ def run(self, New DIASources to be associated with existing DIAObjects. diaObjects : `pandas.DataFrame` Existing diaObjects from the Apdb. - exposure_time : `float`, optional - Exposure time from difference image. Returns ------- @@ -112,30 +90,15 @@ def run(self, matched to new DiaSources. (`int`) - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were not matched a new DiaSource. (`int`) - - ``longTrailedSources`` : DiaSources which have trail lengths - greater than max_trail_length/second*exposure_time. - (`pandas.DataFrame``) """ diaSources = self.check_dia_source_radec(diaSources) - if self.config.doTrailedSourceFilter: - diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time) - diaSources = diaTrailedResult.diaSources - longTrailedSources = diaTrailedResult.longTrailedDiaSources - - self.log.info("%i DiaSources exceed max_trail_length, dropping from source " - "catalog." % len(diaTrailedResult.longTrailedDiaSources)) - self.metadata.add("num_filtered", len(diaTrailedResult.longTrailedDiaSources)) - else: - longTrailedSources = pd.DataFrame(columns=diaSources.columns) - if len(diaObjects) == 0: return pipeBase.Struct( matchedDiaSources=pd.DataFrame(columns=diaSources.columns), unAssocDiaSources=diaSources, nUpdatedDiaObjects=0, - nUnassociatedDiaObjects=0, - longTrailedSources=longTrailedSources) + nUnassociatedDiaObjects=0) matchResult = self.associate_sources(diaObjects, diaSources) @@ -145,8 +108,7 @@ def run(self, matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True), unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True), nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects, - nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects, - longTrailedSources=longTrailedSources) + nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects) def check_dia_source_radec(self, dia_sources): """Check that all DiaSources have non-NaN values for RA/DEC. diff --git a/python/lsst/ap/association/diaPipe.py b/python/lsst/ap/association/diaPipe.py index 9ee2d556..3aacd4b3 100644 --- a/python/lsst/ap/association/diaPipe.py +++ b/python/lsst/ap/association/diaPipe.py @@ -35,23 +35,22 @@ import warnings -import numpy as np -import pandas as pd - -from lsst.daf.base import DateTime import lsst.dax.apdb as daxApdb -from lsst.meas.base import DetectorVisitIdGeneratorConfig, DiaObjectCalculationTask import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase import lsst.pipe.base.connectionTypes as connTypes -from lsst.utils.timer import timeMethod - +import numpy as np +import pandas as pd from lsst.ap.association import ( AssociationTask, DiaForcedSourceTask, LoadDiaCatalogsTask, PackageAlertsTask) from lsst.ap.association.ssoAssociation import SolarSystemAssociationTask +from lsst.daf.base import DateTime +from lsst.meas.base import DetectorVisitIdGeneratorConfig, \ + DiaObjectCalculationTask +from lsst.utils.timer import timeMethod class DiaPipelineConnections( @@ -119,12 +118,6 @@ class DiaPipelineConnections( storageClass="DataFrame", dimensions=("instrument", "visit", "detector"), ) - longTrailedSources = pipeBase.connectionTypes.Output( - doc="Optional output temporarily storing long trailed diaSources.", - dimensions=("instrument", "visit", "detector"), - storageClass="DataFrame", - name="{fakesType}{coaddName}Diff_longTrailedSrc", - ) def __init__(self, *, config=None): super().__init__(config=config) @@ -137,8 +130,6 @@ def __init__(self, *, config=None): self.outputs.remove("diaForcedSources") if not config.doSolarSystemAssociation: self.inputs.remove("solarSystemObjectTable") - if not config.associator.doTrailedSourceFilter: - self.outputs.remove("longTrailedSources") def adjustQuantum(self, inputs, outputs, label, dataId): """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects @@ -438,10 +429,8 @@ def run(self, buffer=self.config.imagePixelMargin) else: diaObjects = loaderResult.diaObjects - # Associate new DiaSources with existing DiaObjects. - assocResults = self.associator.run(diaSourceTable, diaObjects, - exposure_time=diffIm.visitInfo.exposureTime) + assocResults = self.associator.run(diaSourceTable, diaObjects) if self.config.doSolarSystemAssociation: ssoAssocResult = self.solarSystemAssociator.run( @@ -627,7 +616,6 @@ def run(self, associatedDiaSources=associatedDiaSources, diaForcedSources=diaForcedSources, diaObjects=diaObjects, - longTrailedSources=assocResults.longTrailedSources ) def createNewDiaObjects(self, unAssocDiaSources): diff --git a/python/lsst/ap/association/filterDiaSourceCatalog.py b/python/lsst/ap/association/filterDiaSourceCatalog.py index 9330e1e9..0c8f1c84 100644 --- a/python/lsst/ap/association/filterDiaSourceCatalog.py +++ b/python/lsst/ap/association/filterDiaSourceCatalog.py @@ -61,10 +61,28 @@ class FilterDiaSourceCatalogConnections( dimensions={"instrument", "visit", "detector"}, ) + diffImVisitInfo = connTypes.Input( + doc="VisitInfo of diffIm.", + name="{fakesType}{coaddName}Diff_differenceExp.visitInfo", + storageClass="VisitInfo", + dimensions=("instrument", "visit", "detector"), + ) + + longTrailedSources = connTypes.Output( + doc="Optional output temporarily storing long trailed diaSources.", + dimensions=("instrument", "visit", "detector"), + storageClass="ArrowAstropy", + name="{fakesType}{coaddName}Diff_longTrailedSrc", + ) + def __init__(self, *, config=None): super().__init__(config=config) - if not self.config.doWriteRejectedSources: + if not self.config.doWriteRejectedSkySources: self.outputs.remove("rejectedDiaSources") + if not self.config.doTrailedSourceFilter: + self.outputs.remove("longTrailedSources") + if not self.config.doWriteTrailedSources: + self.outputs.remove("longTrailedSources") class FilterDiaSourceCatalogConfig( @@ -79,13 +97,37 @@ class FilterDiaSourceCatalogConfig( "removed before storing the output DiaSource catalog.", ) - doWriteRejectedSources = pexConfig.Field( + doWriteRejectedSkySources = pexConfig.Field( dtype=bool, default=True, doc="Store the output DiaSource catalog containing all the rejected " "sky sources." ) + doTrailedSourceFilter = pexConfig.Field( + doc="Run trailedSourceFilter to remove long trailed sources from the" + "diaSource output catalog.", + dtype=bool, + default=True, + ) + + doWriteTrailedSources = pexConfig.Field( + doc="Write trailed diaSources sources to a table.", + dtype=bool, + default=True, + deprecated="Trailed sources will not be written out during production." + ) + + max_trail_length = pexConfig.Field( + dtype=float, + doc="Length of long trailed sources to remove from the input catalog, " + "in arcseconds per second. Default comes from DMTN-199, which " + "requires removal of sources with trails longer than 10 " + "degrees/day, which is 36000/3600/24 arcsec/second, or roughly" + "0.416 arcseconds per second.", + default=36000/3600.0/24.0, + ) + class FilterDiaSourceCatalogTask(pipeBase.PipelineTask): """Filter out sky sources from a DiaSource catalog.""" @@ -94,13 +136,15 @@ class FilterDiaSourceCatalogTask(pipeBase.PipelineTask): _DefaultName = "filterDiaSourceCatalog" @timeMethod - def run(self, diaSourceCat): + def run(self, diaSourceCat, diffImVisitInfo): """Filter sky sources from the supplied DiaSource catalog. Parameters ---------- diaSourceCat : `lsst.afw.table.SourceCatalog` Catalog of sources measured on the difference image. + diffImVisitInfo: `lsst.afw.image.VisitInfo` + VisitInfo for the difference image corresponding to diaSourceCat. Returns ------- @@ -109,9 +153,13 @@ def run(self, diaSourceCat): ``filteredDiaSourceCat`` : `lsst.afw.table.SourceCatalog` The catalog of filtered sources. ``rejectedDiaSources`` : `lsst.afw.table.SourceCatalog` - The catalog of rejected sources. + The catalog of rejected sky sources. + ``longTrailedDiaSources`` : `astropy.table.Table` + DiaSources which have trail lengths greater than + max_trail_length*exposure_time. """ rejectedSkySources = None + exposure_time = diffImVisitInfo.exposureTime if self.config.doRemoveSkySources: sky_source_column = diaSourceCat["sky_source"] num_sky_sources = np.sum(sky_source_column) @@ -120,6 +168,64 @@ def run(self, diaSourceCat): self.log.info(f"Filtered {num_sky_sources} sky sources.") if not rejectedSkySources: rejectedSkySources = SourceCatalog(diaSourceCat.getSchema()) - filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, - rejectedDiaSources=rejectedSkySources) + + if self.config.doTrailedSourceFilter: + trail_mask = self._check_dia_source_trail(diaSourceCat, exposure_time) + longTrailedDiaSources = diaSourceCat[trail_mask].copy(deep=True) + diaSourceCat = diaSourceCat[~trail_mask] + + self.log.info("%i DiaSources exceed max_trail_length %f arcseconds per second, " + "dropping from source catalog." + % (self.config.max_trail_length, len(diaSourceCat))) + self.metadata.add("num_filtered", len(longTrailedDiaSources)) + + if self.config.doWriteTrailedSources: + filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, + rejectedDiaSources=rejectedSkySources, + longTrailedSources=longTrailedDiaSources.asAstropy()) + else: + filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, + rejectedDiaSources=rejectedSkySources) + + else: + filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, + rejectedDiaSources=rejectedSkySources) + return filterResults + + def _check_dia_source_trail(self, dia_sources, exposure_time): + """Find DiaSources that have long trails or trails with indeterminant + end points. + + Return a mask of sources with lengths greater than + (``config.max_trail_length`` multiplied by the exposure time) + arcseconds. + Additionally, set mask if + ``ext_trailedSources_Naive_flag_off_image`` is set or if + ``ext_trailedSources_Naive_flag_suspect_long_trail`` and + ``ext_trailedSources_Naive_flag_edge`` are both set. + + Parameters + ---------- + dia_sources : `pandas.DataFrame` + Input DIASources to check for trail lengths. + exposure_time : `float` + Exposure time from difference image. + + Returns + ------- + trail_mask : `pandas.DataFrame` + Boolean mask for DIASources which are greater than the + Boolean mask for DIASources which are greater than the + cutoff length or have trails which extend beyond the edge of the + detector (off_image set). Also checks if both + suspect_long_trail and edge are set and masks those sources out. + """ + print(dia_sources.getSchema()) + trail_mask = (dia_sources["ext_trailedSources_Naive_length"] + >= (self.config.max_trail_length*exposure_time)) + trail_mask |= dia_sources['ext_trailedSources_Naive_flag_off_image'] + trail_mask |= (dia_sources['ext_trailedSources_Naive_flag_suspect_long_trail'] + & dia_sources['ext_trailedSources_Naive_flag_edge']) + + return trail_mask diff --git a/python/lsst/ap/association/trailedSourceFilter.py b/python/lsst/ap/association/trailedSourceFilter.py deleted file mode 100644 index 8408ffb2..00000000 --- a/python/lsst/ap/association/trailedSourceFilter.py +++ /dev/null @@ -1,129 +0,0 @@ -# This file is part of ap_association. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -__all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig") - -import os -import numpy as np - -import lsst.pex.config as pexConfig -import lsst.pipe.base as pipeBase -from lsst.utils.timer import timeMethod -from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags -import lsst.utils as utils - - -class TrailedSourceFilterConfig(pexConfig.Config): - """Config class for TrailedSourceFilterTask. - """ - - max_trail_length = pexConfig.Field( - dtype=float, - doc="Length of long trailed sources to remove from the input catalog, " - "in arcseconds per second. Default comes from DMTN-199, which " - "requires removal of sources with trails longer than 10 " - "degrees/day, which is 36000/3600/24 arcsec/second, or roughly" - "0.416 arcseconds per second.", - default=36000/3600.0/24.0, - ) - - -class TrailedSourceFilterTask(pipeBase.Task): - """Find trailed sources in DIASources and filter them as per DMTN-199 - guidelines. - - This task checks the length of trailLength in the DIASource catalog using - a given arcsecond/second rate from max_trail_length and the exposure time. - The two values are used to calculate the maximum allowed trail length and - filters out any trail longer than the maximum. The max_trail_length is - outlined in DMTN-199 and determines the default value. - """ - - ConfigClass = TrailedSourceFilterConfig - _DefaultName = "trailedSourceFilter" - - @timeMethod - def run(self, dia_sources, exposure_time): - """Remove trailed sources longer than ``config.max_trail_length`` from - the input catalog. - - Parameters - ---------- - dia_sources : `pandas.DataFrame` - New DIASources to be checked for trailed sources. - exposure_time : `float` - Exposure time from difference image. - - Returns - ------- - result : `lsst.pipe.base.Struct` - Results struct with components. - - - ``diaSources`` : DIASource table that is free from unwanted - trailed sources. (`pandas.DataFrame`) - - - ``longTrailedDiaSources`` : DIASources that have trails which - exceed max_trail_length/second*exposure_time (seconds). - (`pandas.DataFrame`) - """ - - if "flags" in dia_sources.columns: - flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") - unpacker = UnpackApdbFlags(flag_map, "DiaSource") - flags = unpacker.unpack(dia_sources["flags"], "flags") - trail_edge_flags = flags["ext_trailedSources_Naive_flag_edge"] - else: - trail_edge_flags = dia_sources["trail_flag_edge"] - - trail_mask = self._check_dia_source_trail(dia_sources, exposure_time, trail_edge_flags) - - return pipeBase.Struct( - diaSources=dia_sources[~trail_mask].reset_index(drop=True), - longTrailedDiaSources=dia_sources[trail_mask].reset_index(drop=True)) - - def _check_dia_source_trail(self, dia_sources, exposure_time, trail_edge_flags): - """Find DiaSources that have long trails. - - Return a mask of sources with lengths greater than - ``config.max_trail_length`` multiplied by the exposure time in seconds - or have ext_trailedSources_Naive_flag_edge set. - - Parameters - ---------- - dia_sources : `pandas.DataFrame` - Input DIASources to check for trail lengths. - exposure_time : `float` - Exposure time from difference image. - trail_edge_flags : 'numpy.ndArray' - Boolean array of trail_flag_edge flags from the DIASources. - - Returns - ------- - trail_mask : `pandas.DataFrame` - Boolean mask for DIASources which are greater than the - cutoff length or have the edge flag set. - """ - trail_mask = (dia_sources.loc[:, "trailLength"].values[:] - >= (self.config.max_trail_length*exposure_time)) - - trail_mask[np.where(trail_edge_flags)] = True - - return trail_mask diff --git a/tests/test_association_task.py b/tests/test_association_task.py index 22b3d8fa..62873d73 100644 --- a/tests/test_association_task.py +++ b/tests/test_association_task.py @@ -54,16 +54,14 @@ def setUp(self): "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx, "flags": 0} for idx in range(self.nSources)]) - self.exposure_time = 30.0 def test_run(self): """Test the full task by associating a set of diaSources to existing diaObjects. """ config = AssociationTask.ConfigClass() - config.doTrailedSourceFilter = False assocTask = AssociationTask(config=config) - results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time) + results = assocTask.run(self.diaSources, self.diaObjects) self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1) self.assertEqual(results.nUnassociatedDiaObjects, 1) @@ -73,31 +71,13 @@ def test_run(self): np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2, 3, 4]) np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0]) - def test_run_trailed_sources(self): - """Test the full task by associating a set of diaSources to - existing diaObjects when trailed sources are filtered. - - This should filter out two of the five sources based on trail length, - leaving one unassociated diaSource and two associated diaSources. - """ - assocTask = AssociationTask() - results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time) - - self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3) - self.assertEqual(results.nUnassociatedDiaObjects, 3) - self.assertEqual(len(results.matchedDiaSources), len(self.diaObjects) - 3) - self.assertEqual(len(results.unAssocDiaSources), 1) - np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2]) - np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0]) - def test_run_no_existing_objects(self): """Test the run method with a completely empty database. """ assocTask = AssociationTask() results = assocTask.run( self.diaSources, - pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]), - exposure_time=self.exposure_time) + pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"])) self.assertEqual(results.nUpdatedDiaObjects, 0) self.assertEqual(results.nUnassociatedDiaObjects, 0) self.assertEqual(len(results.matchedDiaSources), 0) diff --git a/tests/test_diaPipe.py b/tests/test_diaPipe.py index 1f3ec49e..edaed345 100644 --- a/tests/test_diaPipe.py +++ b/tests/test_diaPipe.py @@ -183,11 +183,10 @@ def solarSystemAssociator_run(unAssocDiaSources, solarSystemObjectTable, diffIm) ssoAssocDiaSources=_makeMockDataFrame(), unAssocDiaSources=_makeMockDataFrame()) - def associator_run(table, diaObjects, exposure_time=None): + def associator_run(table, diaObjects): return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3, matchedDiaSources=_makeMockDataFrame(), - unAssocDiaSources=_makeMockDataFrame(), - longTrailedSources=None) + unAssocDiaSources=_makeMockDataFrame()) # apdb isn't a subtask, but still needs to be mocked out for correct # execution in the test environment. diff --git a/tests/test_filterDiaSourceCatalog.py b/tests/test_filterDiaSourceCatalog.py index 73cf1b5f..b813a3bd 100644 --- a/tests/test_filterDiaSourceCatalog.py +++ b/tests/test_filterDiaSourceCatalog.py @@ -26,12 +26,14 @@ import lsst.geom as geom import lsst.meas.base.tests as measTests import lsst.utils.tests +import lsst.afw.image as afwImage +import lsst.daf.base as dafBase class TestFilterDiaSourceCatalogTask(unittest.TestCase): def setUp(self): - self.nSources = 10 + self.nSources = 15 self.nSkySources = 5 self.yLoc = 100 self.expId = 4321 @@ -42,24 +44,70 @@ def setUp(self): dataset.addSource(10000.0, geom.Point2D(srcIdx, self.yLoc)) schema = dataset.makeMinimalSchema() schema.addField("sky_source", type="Flag", doc="Sky objects.") + schema.addField('ext_trailedSources_Naive_flag_off_image', type="Flag", + doc="Trail extends off image") + schema.addField('ext_trailedSources_Naive_flag_suspect_long_trail', + type="Flag", doc="Trail length is greater than three times the psf radius") + schema.addField('ext_trailedSources_Naive_flag_edge', type="Flag", + doc="Trail contains edge pixels") + schema.addField('ext_trailedSources_Naive_flag_nan', type="Flag", + doc="One or more trail coordinates are missing") + schema.addField('ext_trailedSources_Naive_length', type="F", + doc="Length of the source trail") _, self.diaSourceCat = dataset.realize(10.0, schema, randomSeed=1234) self.diaSourceCat[0:self.nSkySources]["sky_source"] = True + # The last 10 sources will all contained trail length measurements, + # increasing in size by 1.5 arcseconds. Only the last three will have + # lengths which are too long and will be filtered out. + self.nFilteredTrailedSources = 0 + for srcIdx in range(5, 15): + self.diaSourceCat[srcIdx]["ext_trailedSources_Naive_length"] = 1.5*(srcIdx-4) + if 1.5*(srcIdx-4) > 36000/3600.0/24.0 * 30.0: + self.nFilteredTrailedSources += 1 + # Setting a combination of flags for filtering in tests + self.diaSourceCat[5]["ext_trailedSources_Naive_flag_off_image"] = True + self.diaSourceCat[6]["ext_trailedSources_Naive_flag_suspect_long_trail"] = True + self.diaSourceCat[6]["ext_trailedSources_Naive_flag_edge"] = True + # As only two of these flags are set, the total number of filtered + # sources will be self.nFilteredTrailedSources + 2 + self.nFilteredTrailedSources += 2 self.config = FilterDiaSourceCatalogConfig() + mjd = 57071.0 + self.utc_jd = mjd + 2_400_000.5 - 35.0 / (24.0 * 60.0 * 60.0) + + self.visitInfo = afwImage.VisitInfo( + # This incomplete visitInfo is sufficient for testing because the + # Python constructor sets all other required values to some + # default. + exposureTime=30.0, + darkTime=3.0, + date=dafBase.DateTime(mjd, system=dafBase.DateTime.MJD), + boresightRaDec=geom.SpherePoint(0.0, 0.0, geom.degrees), + ) def test_run_without_filter(self): + """Test that when all filters are turned off all sources in the catalog + are returned. + """ self.config.doRemoveSkySources = False - self.config.doWriteRejectedSources = True + self.config.doWriteRejectedSkySources = False + self.config.doTrailedSourceFilter = False filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) - result = filterDiaSourceCatalogTask.run(self.diaSourceCat) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat)) self.assertEqual(len(result.rejectedDiaSources), 0) self.assertEqual(len(self.diaSourceCat), self.nSources) - def test_run_with_filter(self): + def test_run_with_filter_sky_only(self): + """Test that when only the sky filter is turned on the first five + sources which are flagged as sky objects are filtered out of the + catalog and the rest are returned. + """ self.config.doRemoveSkySources = True - self.config.doWriteRejectedSources = True + self.config.doWriteRejectedSkySources = True + self.config.doTrailedSourceFilter = False filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) - result = filterDiaSourceCatalogTask.run(self.diaSourceCat) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) nExpectedFilteredSources = self.nSources - self.nSkySources self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat[~self.diaSourceCat['sky_source']])) @@ -67,6 +115,39 @@ def test_run_with_filter(self): self.assertEqual(len(result.rejectedDiaSources), self.nSkySources) self.assertEqual(len(self.diaSourceCat), self.nSources) + def test_run_with_filter_trailed_sources_only(self): + """Test that when only the trail filter is turned on the correct number + of sources are filtered out. The filtered sources should be the last + three sources which have long trails, one source where both the suspect + trail and edge trail flag are set, and one source where off_image is + set. All sky objects should remain in the catalog. + """ + self.config.doRemoveSkySources = False + self.config.doWriteRejectedSkySources = False + self.config.doTrailedSourceFilter = True + filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) + nExpectedFilteredSources = self.nSources - self.nFilteredTrailedSources + self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources) + self.assertEqual(len(self.diaSourceCat), self.nSources) + + def test_run_with_all_filters(self): + """Test that all sources are filtered out correctly. Only six sources + should remain in the catalog after filtering. + """ + self.config.doRemoveSkySources = True + self.config.doWriteRejectedSkySources = True + self.config.doTrailedSourceFilter = True + filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) + nExpectedFilteredSources = self.nSources - self.nSkySources - self.nFilteredTrailedSources + # 5 filtered out sky sources + # 4 filtered out trailed sources, 2 with long trails 2 with flags + # 6 sources left + self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources) + self.assertEqual(len(result.rejectedDiaSources), self.nSkySources) + self.assertEqual(len(self.diaSourceCat), self.nSources) + class MemoryTester(lsst.utils.tests.MemoryTestCase): pass diff --git a/tests/test_trailedSourceFilter.py b/tests/test_trailedSourceFilter.py deleted file mode 100644 index e777aaea..00000000 --- a/tests/test_trailedSourceFilter.py +++ /dev/null @@ -1,167 +0,0 @@ -# This file is part of ap_association. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -import unittest -import os -import numpy as np -import pandas as pd - -import lsst.utils.tests -import lsst.utils as utils -from lsst.ap.association import TrailedSourceFilterTask -from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags - - -class TestTrailedSourceFilterTask(unittest.TestCase): - - def setUp(self): - """Create sets of diaSources. - - The trail lengths of the dia sources are 0, 5.5, 11, 16.5, 21.5 - arcseconds. - """ - # Create an instance of random generator with fixed seed. - rng = np.random.default_rng(1234) - - scatter = 0.1 / 3600 - self.nSources = 5 - self.diaSources = pd.DataFrame(data=[ - {"ra": 0.04*idx + scatter*rng.uniform(-1, 1), - "dec": 0.04*idx + scatter*rng.uniform(-1, 1), - "diaSourceId": idx, "diaObjectId": 0, "trailLength": 5.5*idx, - "flags": 0} - for idx in range(self.nSources)]) - self.exposure_time = 30.0 - - # For use only with testing the edge flag - self.edgeDiaSources = pd.DataFrame(data=[ - {"ra": 0.04*idx + scatter*rng.uniform(-1, 1), - "dec": 0.04*idx + scatter*rng.uniform(-1, 1), - "diaSourceId": idx, "diaObjectId": 0, "trailLength": 0, - "flags": 0} - for idx in range(self.nSources)]) - - flagMap = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") - unpacker = UnpackApdbFlags(flagMap, "DiaSource") - bitMask = unpacker.makeFlagBitMask(["ext_trailedSources_Naive_flag_edge"]) - # Flag two sources as "trailed on the edge". - self.edgeDiaSources.loc[[1, 4], "flags"] |= bitMask - - def test_run(self): - """Run trailedSourceFilterTask with the default max distance. - - With the default settings and an exposure of 30 seconds, the max trail - length is 12.5 arcseconds. Two out of five of the diaSources will be - filtered out of the final results and put into results.trailedSources. - """ - trailedSourceFilterTask = TrailedSourceFilterTask() - - results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 3) - np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 1, 2]) - np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [3, 4]) - - def test_run_short_max_trail(self): - """Run trailedSourceFilterTask with aggressive trail length cutoff - - With a max_trail_length config of 0.01 arcseconds/second and an - exposure of 30 seconds,the max trail length is 0.3 arcseconds. Only the - source with a trail of 0 stays in the catalog and the rest are filtered - out and put into results.trailedSources. - """ - config = TrailedSourceFilterTask.ConfigClass() - config.max_trail_length = 0.01 - trailedSourceFilterTask = TrailedSourceFilterTask(config=config) - results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 1) - np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0]) - np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 2, 3, 4]) - - def test_run_no_trails(self): - """Run trailedSourceFilterTask with a long trail length so that - every source in the catalog is in the final diaSource catalog. - - With a max_trail_length config of 10 arcseconds/second and an - exposure of 30 seconds,the max trail length is 300 arcseconds. All - sources in the initial catalog should be in the final diaSource - catalog. - """ - config = TrailedSourceFilterTask.ConfigClass() - config.max_trail_length = 10.00 - trailedSourceFilterTask = TrailedSourceFilterTask(config=config) - results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 5) - self.assertEqual(len(results.longTrailedDiaSources), 0) - np.testing.assert_array_equal(results.diaSources["diaSourceId"].values, [0, 1, 2, 3, 4]) - np.testing.assert_array_equal(results.longTrailedDiaSources["diaSourceId"].values, []) - - def test_run_edge(self): - """Run trailedSourceFilterTask on a source on the edge. - """ - trailedSourceFilterTask = TrailedSourceFilterTask() - - results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 3) - np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 2, 3]) - np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 4]) - - def test_check_dia_source_trail(self): - """Test that the DiaSource trail checker is correctly identifying - long trails - - Test that the trail source mask filter returns the expected mask array. - """ - trailedSourceFilterTask = TrailedSourceFilterTask() - flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") - unpacker = UnpackApdbFlags(flag_map, "DiaSource") - flags = unpacker.unpack(self.diaSources["flags"], "flags") - trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, - self.exposure_time, flags) - - np.testing.assert_array_equal(trailed_source_mask, [False, False, False, True, True]) - - flags = unpacker.unpack(self.edgeDiaSources["flags"], "flags") - trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.edgeDiaSources, - self.exposure_time, flags) - np.testing.assert_array_equal(trailed_source_mask, [False, True, False, False, True]) - - # Mixing the flags from edgeDiaSources and diaSources means the mask - # will be set using both criteria. - trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, - self.exposure_time, flags) - np.testing.assert_array_equal(trailed_source_mask, [False, True, False, True, True]) - - -class MemoryTester(lsst.utils.tests.MemoryTestCase): - pass - - -def setup_module(module): - lsst.utils.tests.init() - - -if __name__ == "__main__": - lsst.utils.tests.init() - unittest.main()