From a272da69ed15120812c6d2c4f18ee26de706923f Mon Sep 17 00:00:00 2001 From: Brianna Smart Date: Fri, 19 Apr 2024 14:15:53 -0700 Subject: [PATCH] Move trailed source filtering to filterDiaSourceCatalog Move the trailed source filtering into filterDiaSourceCatalog. Additionally, update filtering so only trails that extend of the edge or which have long trails and are near the edge are filtered out, along with the sources which exceed the max trail length. --- data/association-flag-map.yaml | 5 +- python/lsst/ap/association/__init__.py | 1 - python/lsst/ap/association/association.py | 44 +---- python/lsst/ap/association/diaPipe.py | 26 +-- .../ap/association/filterDiaSourceCatalog.py | 117 ++++++++++++- .../ap/association/trailedSourceFilter.py | 125 ------------- tests/test_association_task.py | 24 +-- tests/test_diaPipe.py | 5 +- tests/test_filterDiaSourceCatalog.py | 82 ++++++++- tests/test_trailedSourceFilter.py | 165 ------------------ 10 files changed, 202 insertions(+), 392 deletions(-) delete mode 100644 python/lsst/ap/association/trailedSourceFilter.py delete mode 100644 tests/test_trailedSourceFilter.py diff --git a/data/association-flag-map.yaml b/data/association-flag-map.yaml index 3e8176d0..9b0c5492 100644 --- a/data/association-flag-map.yaml +++ b/data/association-flag-map.yaml @@ -82,9 +82,6 @@ columns: - name: slot_Shape_flag_parent_source bit: 26 doc: parent source, ignored; only valid for HsmShape - - name: ext_trailedSources_Naive_flag_edge - bit: 27 - doc: source is trailed and extends off chip - name: base_PixelFlags_flag_streak bit: 28 doc: Streak in the Source footprint @@ -102,4 +99,4 @@ columns: doc: Fake source template injection in source footprint - name: base_PixelFlags_flag_injected_templateCenter bit: 33 - doc: Fake source template injection center in source footprint + doc: Fake source template injection center in source footprint \ No newline at end of file diff --git a/python/lsst/ap/association/__init__.py b/python/lsst/ap/association/__init__.py index 16cd34da..fd84f9a7 100644 --- a/python/lsst/ap/association/__init__.py +++ b/python/lsst/ap/association/__init__.py @@ -25,6 +25,5 @@ from .loadDiaCatalogs import * from .packageAlerts import * from .diaPipe import * -from .trailedSourceFilter import * from .transformDiaSourceCatalog import * from .version import * \ No newline at end of file diff --git a/python/lsst/ap/association/association.py b/python/lsst/ap/association/association.py index d2bd043a..eb4e0e41 100644 --- a/python/lsst/ap/association/association.py +++ b/python/lsst/ap/association/association.py @@ -32,7 +32,6 @@ import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase from lsst.utils.timer import timeMethod -from .trailedSourceFilter import TrailedSourceFilterTask # Enforce an error for unsafe column/array value setting in pandas. pd.options.mode.chained_assignment = 'raise' @@ -49,19 +48,6 @@ class AssociationConfig(pexConfig.Config): default=1.0, ) - trailedSourceFilter = pexConfig.ConfigurableField( - target=TrailedSourceFilterTask, - doc="Subtask to remove long trailed sources based on catalog source " - "morphological measurements.", - ) - - doTrailedSourceFilter = pexConfig.Field( - doc="Run traildeSourceFilter to remove long trailed sources from " - "output catalog.", - dtype=bool, - default=True, - ) - class AssociationTask(pipeBase.Task): """Associate DIAOSources into existing DIAObjects. @@ -75,16 +61,10 @@ class AssociationTask(pipeBase.Task): ConfigClass = AssociationConfig _DefaultName = "association" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.config.doTrailedSourceFilter: - self.makeSubtask("trailedSourceFilter") - @timeMethod def run(self, diaSources, - diaObjects, - exposure_time=None): + diaObjects): """Associate the new DiaSources with existing DiaObjects. Parameters @@ -93,8 +73,6 @@ def run(self, New DIASources to be associated with existing DIAObjects. diaObjects : `pandas.DataFrame` Existing diaObjects from the Apdb. - exposure_time : `float`, optional - Exposure time from difference image. Returns ------- @@ -112,30 +90,15 @@ def run(self, matched to new DiaSources. (`int`) - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were not matched a new DiaSource. (`int`) - - ``longTrailedSources`` : DiaSources which have trail lengths - greater than max_trail_length/second*exposure_time. - (`pandas.DataFrame``) """ diaSources = self.check_dia_source_radec(diaSources) - if self.config.doTrailedSourceFilter: - diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time) - diaSources = diaTrailedResult.diaSources - longTrailedSources = diaTrailedResult.longTrailedDiaSources - - self.log.info("%i DiaSources exceed max_trail_length, dropping from source " - "catalog." % len(diaTrailedResult.longTrailedDiaSources)) - self.metadata.add("num_filtered", len(diaTrailedResult.longTrailedDiaSources)) - else: - longTrailedSources = pd.DataFrame(columns=diaSources.columns) - if len(diaObjects) == 0: return pipeBase.Struct( matchedDiaSources=pd.DataFrame(columns=diaSources.columns), unAssocDiaSources=diaSources, nUpdatedDiaObjects=0, - nUnassociatedDiaObjects=0, - longTrailedSources=longTrailedSources) + nUnassociatedDiaObjects=0) matchResult = self.associate_sources(diaObjects, diaSources) @@ -145,8 +108,7 @@ def run(self, matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True), unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True), nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects, - nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects, - longTrailedSources=longTrailedSources) + nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects) def check_dia_source_radec(self, dia_sources): """Check that all DiaSources have non-NaN values for RA/DEC. diff --git a/python/lsst/ap/association/diaPipe.py b/python/lsst/ap/association/diaPipe.py index 3629bfed..53faa63d 100644 --- a/python/lsst/ap/association/diaPipe.py +++ b/python/lsst/ap/association/diaPipe.py @@ -32,23 +32,22 @@ "DiaPipelineTask", "DiaPipelineConnections") -import numpy as np -import pandas as pd - -from lsst.daf.base import DateTime import lsst.dax.apdb as daxApdb -from lsst.meas.base import DetectorVisitIdGeneratorConfig, DiaObjectCalculationTask import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase import lsst.pipe.base.connectionTypes as connTypes -from lsst.utils.timer import timeMethod - +import numpy as np +import pandas as pd from lsst.ap.association import ( AssociationTask, DiaForcedSourceTask, LoadDiaCatalogsTask, PackageAlertsTask) from lsst.ap.association.ssoAssociation import SolarSystemAssociationTask +from lsst.daf.base import DateTime +from lsst.meas.base import DetectorVisitIdGeneratorConfig, \ + DiaObjectCalculationTask +from lsst.utils.timer import timeMethod class DiaPipelineConnections( @@ -116,12 +115,6 @@ class DiaPipelineConnections( storageClass="DataFrame", dimensions=("instrument", "visit", "detector"), ) - longTrailedSources = pipeBase.connectionTypes.Output( - doc="Optional output temporarily storing long trailed diaSources.", - dimensions=("instrument", "visit", "detector"), - storageClass="DataFrame", - name="{fakesType}{coaddName}Diff_longTrailedSrc", - ) def __init__(self, *, config=None): super().__init__(config=config) @@ -134,8 +127,6 @@ def __init__(self, *, config=None): self.outputs.remove("diaForcedSources") if not config.doSolarSystemAssociation: self.inputs.remove("solarSystemObjectTable") - if not config.associator.doTrailedSourceFilter: - self.outputs.remove("longTrailedSources") def adjustQuantum(self, inputs, outputs, label, dataId): """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects @@ -398,10 +389,8 @@ def run(self, buffer=self.config.imagePixelMargin) else: diaObjects = loaderResult.diaObjects - # Associate new DiaSources with existing DiaObjects. - assocResults = self.associator.run(diaSourceTable, diaObjects, - exposure_time=diffIm.visitInfo.exposureTime) + assocResults = self.associator.run(diaSourceTable, diaObjects) if self.config.doSolarSystemAssociation: ssoAssocResult = self.solarSystemAssociator.run( @@ -584,7 +573,6 @@ def run(self, associatedDiaSources=associatedDiaSources, diaForcedSources=diaForcedSources, diaObjects=diaObjects, - longTrailedSources=assocResults.longTrailedSources ) def createNewDiaObjects(self, unAssocDiaSources): diff --git a/python/lsst/ap/association/filterDiaSourceCatalog.py b/python/lsst/ap/association/filterDiaSourceCatalog.py index 9330e1e9..a7b476c9 100644 --- a/python/lsst/ap/association/filterDiaSourceCatalog.py +++ b/python/lsst/ap/association/filterDiaSourceCatalog.py @@ -61,10 +61,28 @@ class FilterDiaSourceCatalogConnections( dimensions={"instrument", "visit", "detector"}, ) + diffImVisitInfo = connTypes.Input( + doc="VisitInfo of diffIm.", + name="{fakesType}{coaddName}Diff_differenceExp.visitInfo", + storageClass="VisitInfo", + dimensions=("instrument", "visit", "detector"), + ) + + longTrailedSources = connTypes.Output( + doc="Optional output temporarily storing long trailed diaSources.", + dimensions=("instrument", "visit", "detector"), + storageClass="ArrowAstropy", + name="{fakesType}{coaddName}Diff_longTrailedSrc", + ) + def __init__(self, *, config=None): super().__init__(config=config) - if not self.config.doWriteRejectedSources: + if not self.config.doWriteRejectedSkySources: self.outputs.remove("rejectedDiaSources") + if not self.config.doTrailedSourceFilter: + self.outputs.remove("longTrailedSources") + if not self.config.doWriteTrailedSources: + self.outputs.remove("longTrailedSources") class FilterDiaSourceCatalogConfig( @@ -79,13 +97,37 @@ class FilterDiaSourceCatalogConfig( "removed before storing the output DiaSource catalog.", ) - doWriteRejectedSources = pexConfig.Field( + doWriteRejectedSkySources = pexConfig.Field( dtype=bool, default=True, doc="Store the output DiaSource catalog containing all the rejected " "sky sources." ) + doTrailedSourceFilter = pexConfig.Field( + doc="Run trailedSourceFilter to remove long trailed sources from " + "output catalog.", + dtype=bool, + default=True, + ) + + doWriteTrailedSources = pexConfig.Field( + doc="Write trailed sources to a table.", + dtype=bool, + default=True, + deprecated="Trailed sources will not be written out during production." + ) + + max_trail_length = pexConfig.Field( + dtype=float, + doc="Length of long trailed sources to remove from the input catalog, " + "in arcseconds per second. Default comes from DMTN-199, which " + "requires removal of sources with trails longer than 10 " + "degrees/day, which is 36000/3600/24 arcsec/second, or roughly" + "0.416 arcseconds per second.", + default=36000/3600.0/24.0, + ) + class FilterDiaSourceCatalogTask(pipeBase.PipelineTask): """Filter out sky sources from a DiaSource catalog.""" @@ -94,13 +136,15 @@ class FilterDiaSourceCatalogTask(pipeBase.PipelineTask): _DefaultName = "filterDiaSourceCatalog" @timeMethod - def run(self, diaSourceCat): + def run(self, diaSourceCat, diffImVisitInfo): """Filter sky sources from the supplied DiaSource catalog. Parameters ---------- diaSourceCat : `lsst.afw.table.SourceCatalog` Catalog of sources measured on the difference image. + diffImVisitInfo: `lsst.afw.image.VisitInfo` + Visit information for a particular exposure. Returns ------- @@ -109,9 +153,13 @@ def run(self, diaSourceCat): ``filteredDiaSourceCat`` : `lsst.afw.table.SourceCatalog` The catalog of filtered sources. ``rejectedDiaSources`` : `lsst.afw.table.SourceCatalog` - The catalog of rejected sources. + The catalog of rejected sky sources. + ``longTrailedDiaSources`` : `astropy.table.Table` + DiaSources which have trail lengths greater than + max_trail_length/second*exposure_time. """ rejectedSkySources = None + exposure_time = diffImVisitInfo.exposureTime if self.config.doRemoveSkySources: sky_source_column = diaSourceCat["sky_source"] num_sky_sources = np.sum(sky_source_column) @@ -120,6 +168,63 @@ def run(self, diaSourceCat): self.log.info(f"Filtered {num_sky_sources} sky sources.") if not rejectedSkySources: rejectedSkySources = SourceCatalog(diaSourceCat.getSchema()) - filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, - rejectedDiaSources=rejectedSkySources) + + if self.config.doTrailedSourceFilter: + trail_mask = self._check_dia_source_trail(diaSourceCat, exposure_time) + longTrailedDiaSources = diaSourceCat[trail_mask].copy(deep=True) + diaSourceCat = diaSourceCat[~trail_mask] + + self.log.info( + "%i DiaSources exceed max_trail_length, dropping from source " + "catalog." % len(diaSourceCat)) + self.metadata.add("num_filtered", len(longTrailedDiaSources)) + + if self.config.doWriteTrailedSources: + filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, + rejectedDiaSources=rejectedSkySources, + longTrailedSources=longTrailedDiaSources.asAstropy()) + else: + filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, + rejectedDiaSources=rejectedSkySources) + + else: + filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat, + rejectedDiaSources=rejectedSkySources) + return filterResults + + def _check_dia_source_trail(self, dia_sources, exposure_time): + """Find DiaSources that have long trails or trails with indeterminant + end points. + + Return a mask of sources with lengths greater than + ``config.max_trail_length`` multiplied by the exposure time in + seconds. Additionally, set mask if + ``ext_trailedSources_Naive_flag_off_image`` is set or if + ``ext_trailedSources_Naive_flag_suspect_long_trail`` and + ``ext_trailedSources_Naive_flag_edge`` are both set. + + Parameters + ---------- + dia_sources : `pandas.DataFrame` + Input DIASources to check for trail lengths. + exposure_time : `float` + Exposure time from difference image. + + Returns + ------- + trail_mask : `pandas.DataFrame` + Boolean mask for DIASources which are greater than the + Boolean mask for DIASources which are greater than the + cutoff length or have trails which extend beyond the edge of the + detector (off_image set). Also checks if both + suspect_long_trail and edge are set and masks those sources out. + """ + print(dia_sources.getSchema()) + trail_mask = (dia_sources["ext_trailedSources_Naive_length"] + >= (self.config.max_trail_length*exposure_time)) + trail_mask |= dia_sources['ext_trailedSources_Naive_flag_off_image'] + trail_mask |= (dia_sources['ext_trailedSources_Naive_flag_suspect_long_trail'] + & dia_sources['ext_trailedSources_Naive_flag_edge']) + + return trail_mask diff --git a/python/lsst/ap/association/trailedSourceFilter.py b/python/lsst/ap/association/trailedSourceFilter.py deleted file mode 100644 index d7721928..00000000 --- a/python/lsst/ap/association/trailedSourceFilter.py +++ /dev/null @@ -1,125 +0,0 @@ -# This file is part of ap_association. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -__all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig") - -import os -import numpy as np - -import lsst.pex.config as pexConfig -import lsst.pipe.base as pipeBase -from lsst.utils.timer import timeMethod -from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags -import lsst.utils as utils - - -class TrailedSourceFilterConfig(pexConfig.Config): - """Config class for TrailedSourceFilterTask. - """ - - max_trail_length = pexConfig.Field( - dtype=float, - doc="Length of long trailed sources to remove from the input catalog, " - "in arcseconds per second. Default comes from DMTN-199, which " - "requires removal of sources with trails longer than 10 " - "degrees/day, which is 36000/3600/24 arcsec/second, or roughly" - "0.416 arcseconds per second.", - default=36000/3600.0/24.0, - ) - - -class TrailedSourceFilterTask(pipeBase.Task): - """Find trailed sources in DIASources and filter them as per DMTN-199 - guidelines. - - This task checks the length of trailLength in the DIASource catalog using - a given arcsecond/second rate from max_trail_length and the exposure time. - The two values are used to calculate the maximum allowed trail length and - filters out any trail longer than the maximum. The max_trail_length is - outlined in DMTN-199 and determines the default value. - """ - - ConfigClass = TrailedSourceFilterConfig - _DefaultName = "trailedSourceFilter" - - @timeMethod - def run(self, dia_sources, exposure_time): - """Remove trailed sources longer than ``config.max_trail_length`` from - the input catalog. - - Parameters - ---------- - dia_sources : `pandas.DataFrame` - New DIASources to be checked for trailed sources. - exposure_time : `float` - Exposure time from difference image. - - Returns - ------- - result : `lsst.pipe.base.Struct` - Results struct with components. - - - ``diaSources`` : DIASource table that is free from unwanted - trailed sources. (`pandas.DataFrame`) - - - ``longTrailedDiaSources`` : DIASources that have trails which - exceed max_trail_length/second*exposure_time (seconds). - (`pandas.DataFrame`) - """ - - flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") - unpacker = UnpackApdbFlags(flag_map, "DiaSource") - flags = unpacker.unpack(dia_sources["flags"], "flags") - - trail_mask = self._check_dia_source_trail(dia_sources, exposure_time, flags) - - return pipeBase.Struct( - diaSources=dia_sources[~trail_mask].reset_index(drop=True), - longTrailedDiaSources=dia_sources[trail_mask].reset_index(drop=True)) - - def _check_dia_source_trail(self, dia_sources, exposure_time, flags): - """Find DiaSources that have long trails. - - Return a mask of sources with lengths greater than - ``config.max_trail_length`` multiplied by the exposure time in seconds - or have ext_trailedSources_Naive_flag_edge set. - - Parameters - ---------- - dia_sources : `pandas.DataFrame` - Input DIASources to check for trail lengths. - exposure_time : `float` - Exposure time from difference image. - flags : 'numpy.ndArray' - Boolean array of flags from the DIASources. - - Returns - ------- - trail_mask : `pandas.DataFrame` - Boolean mask for DIASources which are greater than the - cutoff length and have the edge flag set. - """ - trail_mask = (dia_sources.loc[:, "trailLength"].values[:] - >= (self.config.max_trail_length*exposure_time)) - - trail_mask[np.where(flags['ext_trailedSources_Naive_flag_edge'])] = True - - return trail_mask diff --git a/tests/test_association_task.py b/tests/test_association_task.py index 22b3d8fa..62873d73 100644 --- a/tests/test_association_task.py +++ b/tests/test_association_task.py @@ -54,16 +54,14 @@ def setUp(self): "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx, "flags": 0} for idx in range(self.nSources)]) - self.exposure_time = 30.0 def test_run(self): """Test the full task by associating a set of diaSources to existing diaObjects. """ config = AssociationTask.ConfigClass() - config.doTrailedSourceFilter = False assocTask = AssociationTask(config=config) - results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time) + results = assocTask.run(self.diaSources, self.diaObjects) self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1) self.assertEqual(results.nUnassociatedDiaObjects, 1) @@ -73,31 +71,13 @@ def test_run(self): np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2, 3, 4]) np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0]) - def test_run_trailed_sources(self): - """Test the full task by associating a set of diaSources to - existing diaObjects when trailed sources are filtered. - - This should filter out two of the five sources based on trail length, - leaving one unassociated diaSource and two associated diaSources. - """ - assocTask = AssociationTask() - results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time) - - self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3) - self.assertEqual(results.nUnassociatedDiaObjects, 3) - self.assertEqual(len(results.matchedDiaSources), len(self.diaObjects) - 3) - self.assertEqual(len(results.unAssocDiaSources), 1) - np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2]) - np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0]) - def test_run_no_existing_objects(self): """Test the run method with a completely empty database. """ assocTask = AssociationTask() results = assocTask.run( self.diaSources, - pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]), - exposure_time=self.exposure_time) + pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"])) self.assertEqual(results.nUpdatedDiaObjects, 0) self.assertEqual(results.nUnassociatedDiaObjects, 0) self.assertEqual(len(results.matchedDiaSources), 0) diff --git a/tests/test_diaPipe.py b/tests/test_diaPipe.py index 69c789e9..de9c9d87 100644 --- a/tests/test_diaPipe.py +++ b/tests/test_diaPipe.py @@ -139,11 +139,10 @@ def solarSystemAssociator_run(self, unAssocDiaSources, solarSystemObjectTable, d unAssocDiaSources=_makeMockDataFrame()) @lsst.utils.timer.timeMethod - def associator_run(self, table, diaObjects, exposure_time=None): + def associator_run(self, table, diaObjects): return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3, matchedDiaSources=_makeMockDataFrame(), - unAssocDiaSources=_makeMockDataFrame(), - longTrailedSources=None) + unAssocDiaSources=_makeMockDataFrame()) # apdb isn't a subtask, but still needs to be mocked out for correct # execution in the test environment. diff --git a/tests/test_filterDiaSourceCatalog.py b/tests/test_filterDiaSourceCatalog.py index 73cf1b5f..62dc8643 100644 --- a/tests/test_filterDiaSourceCatalog.py +++ b/tests/test_filterDiaSourceCatalog.py @@ -26,12 +26,14 @@ import lsst.geom as geom import lsst.meas.base.tests as measTests import lsst.utils.tests +import lsst.afw.image as afwImage +import lsst.daf.base as dafBase class TestFilterDiaSourceCatalogTask(unittest.TestCase): def setUp(self): - self.nSources = 10 + self.nSources = 15 self.nSkySources = 5 self.yLoc = 100 self.expId = 4321 @@ -42,24 +44,59 @@ def setUp(self): dataset.addSource(10000.0, geom.Point2D(srcIdx, self.yLoc)) schema = dataset.makeMinimalSchema() schema.addField("sky_source", type="Flag", doc="Sky objects.") + schema.addField('ext_trailedSources_Naive_flag_off_image', type="Flag", doc="off_image") + schema.addField('ext_trailedSources_Naive_flag_suspect_long_trail', + type="Flag", doc="suspect_long_trail") + schema.addField('ext_trailedSources_Naive_flag_edge', type="Flag", doc="edge") + schema.addField('ext_trailedSources_Naive_flag_nan', type="Flag", doc="nan") + schema.addField('ext_trailedSources_Naive_length', type="F", doc="trail length") _, self.diaSourceCat = dataset.realize(10.0, schema, randomSeed=1234) self.diaSourceCat[0:self.nSkySources]["sky_source"] = True + self.nLongTrailedSources = 0 + for srcIdx in range(5, 15): + self.diaSourceCat[srcIdx]["ext_trailedSources_Naive_length"] = 1.5*(srcIdx-4) + if 1.5*(srcIdx-4) > 36000/3600.0/24.0 * 30.0: + self.nLongTrailedSources += 1 + self.diaSourceCat[5]["ext_trailedSources_Naive_flag_off_image"] = True + self.diaSourceCat[6]["ext_trailedSources_Naive_flag_suspect_long_trail"] = True + self.nLongTrailedSources += 2 + self.diaSourceCat[6]["ext_trailedSources_Naive_flag_edge"] = True self.config = FilterDiaSourceCatalogConfig() + mjd = 57071.0 + self.utc_jd = mjd + 2_400_000.5 - 35.0 / (24.0 * 60.0 * 60.0) + + self.visitInfo = afwImage.VisitInfo( + # Incomplete VisitInfo; Python constructor allows any value to + # be defaulted. + exposureTime=30.0, + darkTime=3.0, + date=dafBase.DateTime(mjd, system=dafBase.DateTime.MJD), + boresightRaDec=geom.SpherePoint(0.0, 0.0, geom.degrees), + ) def test_run_without_filter(self): + """Test that when all filters are turned off all sources in the catalog + are returned. + """ self.config.doRemoveSkySources = False - self.config.doWriteRejectedSources = True + self.config.doWriteRejectedSkySources = False + self.config.doTrailedSourceFilter = False filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) - result = filterDiaSourceCatalogTask.run(self.diaSourceCat) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat)) self.assertEqual(len(result.rejectedDiaSources), 0) self.assertEqual(len(self.diaSourceCat), self.nSources) - def test_run_with_filter(self): + def test_run_with_filter_sky_only(self): + """Test that when only the sky filter is turned on the first five + sources which are flagged as sky objects are filtered out of the + catalog and the rest are returned. + """ self.config.doRemoveSkySources = True - self.config.doWriteRejectedSources = True + self.config.doWriteRejectedSkySources = True + self.config.doTrailedSourceFilter = False filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) - result = filterDiaSourceCatalogTask.run(self.diaSourceCat) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) nExpectedFilteredSources = self.nSources - self.nSkySources self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat[~self.diaSourceCat['sky_source']])) @@ -67,6 +104,39 @@ def test_run_with_filter(self): self.assertEqual(len(result.rejectedDiaSources), self.nSkySources) self.assertEqual(len(self.diaSourceCat), self.nSources) + def test_run_with_filter_trailed_sources_only(self): + """Test that when only the trail filter is turned on the the last three + sources which have long trails, the source with both the suspect trail + and edge trail flag are set, and the source which is flagged as + continuing off edge are all filtered out. All sky objects should remain + in the catalog. + """ + self.config.doRemoveSkySources = False + self.config.doWriteRejectedSkySources = False + self.config.doTrailedSourceFilter = True + filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) + nExpectedFilteredSources = self.nSources - self.nLongTrailedSources + self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources) + self.assertEqual(len(self.diaSourceCat), self.nSources) + + def test_run_with_all_filters(self): + """Test that all sources are filtered out correctly. Only six sources + should remain in the catalog after filtering. + """ + self.config.doRemoveSkySources = True + self.config.doWriteRejectedSkySources = True + self.config.doTrailedSourceFilter = True + filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo) + nExpectedFilteredSources = self.nSources - self.nSkySources - self.nLongTrailedSources + # 5 filtered out sky sources + # 4 filtered out trailed sources, 2 with long trails 2 with flags + # 6 sources left + self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources) + self.assertEqual(len(result.rejectedDiaSources), self.nSkySources) + self.assertEqual(len(self.diaSourceCat), self.nSources) + class MemoryTester(lsst.utils.tests.MemoryTestCase): pass diff --git a/tests/test_trailedSourceFilter.py b/tests/test_trailedSourceFilter.py deleted file mode 100644 index 3dfc1974..00000000 --- a/tests/test_trailedSourceFilter.py +++ /dev/null @@ -1,165 +0,0 @@ -# This file is part of ap_association. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -import unittest -import os -import numpy as np -import pandas as pd - -import lsst.utils.tests -import lsst.utils as utils -from lsst.ap.association import TrailedSourceFilterTask -from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags - - -class TestTrailedSourceFilterTask(unittest.TestCase): - - def setUp(self): - """Create sets of diaSources. - - The trail lengths of the dia sources are 0, 5.5, 11, 16.5, 21.5 - arcseconds. - """ - rng = np.random.default_rng(1234) - scatter = 0.1 / 3600 - self.nSources = 5 - self.diaSources = pd.DataFrame(data=[ - {"ra": 0.04*idx + scatter*rng.uniform(-1, 1), - "dec": 0.04*idx + scatter*rng.uniform(-1, 1), - "diaSourceId": idx, "diaObjectId": 0, "trailLength": 5.5*idx, - "flags": 0} - for idx in range(self.nSources)]) - self.exposure_time = 30.0 - - # For use only with testing the edge flag - self.edgeDiaSources = pd.DataFrame(data=[ - {"ra": 0.04*idx + scatter*rng.uniform(-1, 1), - "dec": 0.04*idx + scatter*rng.uniform(-1, 1), - "diaSourceId": idx, "diaObjectId": 0, "trailLength": 0, - "flags": 0} - for idx in range(self.nSources)]) - - flagMap = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") - unpacker = UnpackApdbFlags(flagMap, "DiaSource") - bitMask = unpacker.makeFlagBitMask(["ext_trailedSources_Naive_flag_edge"]) - # Flag two sources as "trailed on the edge". - self.edgeDiaSources.loc[[1, 4], "flags"] |= bitMask - - def test_run(self): - """Run trailedSourceFilterTask with the default max distance. - - With the default settings and an exposure of 30 seconds, the max trail - length is 12.5 arcseconds. Two out of five of the diaSources will be - filtered out of the final results and put into results.trailedSources. - """ - trailedSourceFilterTask = TrailedSourceFilterTask() - - results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 3) - np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 1, 2]) - np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [3, 4]) - - def test_run_short_max_trail(self): - """Run trailedSourceFilterTask with aggressive trail length cutoff - - With a max_trail_length config of 0.01 arcseconds/second and an - exposure of 30 seconds,the max trail length is 0.3 arcseconds. Only the - source with a trail of 0 stays in the catalog and the rest are filtered - out and put into results.trailedSources. - """ - config = TrailedSourceFilterTask.ConfigClass() - config.max_trail_length = 0.01 - trailedSourceFilterTask = TrailedSourceFilterTask(config=config) - results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 1) - np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0]) - np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 2, 3, 4]) - - def test_run_no_trails(self): - """Run trailedSourceFilterTask with a long trail length so that - every source in the catalog is in the final diaSource catalog. - - With a max_trail_length config of 10 arcseconds/second and an - exposure of 30 seconds,the max trail length is 300 arcseconds. All - sources in the initial catalog should be in the final diaSource - catalog. - """ - config = TrailedSourceFilterTask.ConfigClass() - config.max_trail_length = 10.00 - trailedSourceFilterTask = TrailedSourceFilterTask(config=config) - results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 5) - self.assertEqual(len(results.longTrailedDiaSources), 0) - np.testing.assert_array_equal(results.diaSources["diaSourceId"].values, [0, 1, 2, 3, 4]) - np.testing.assert_array_equal(results.longTrailedDiaSources["diaSourceId"].values, []) - - def test_run_edge(self): - """Run trailedSourceFilterTask on a source on the edge. - """ - trailedSourceFilterTask = TrailedSourceFilterTask() - - results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time) - - self.assertEqual(len(results.diaSources), 3) - np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 2, 3]) - np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 4]) - - def test_check_dia_source_trail(self): - """Test that the DiaSource trail checker is correctly identifying - long trails - - Test that the trail source mask filter returns the expected mask array. - """ - trailedSourceFilterTask = TrailedSourceFilterTask() - flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") - unpacker = UnpackApdbFlags(flag_map, "DiaSource") - flags = unpacker.unpack(self.diaSources["flags"], "flags") - trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, - self.exposure_time, flags) - - np.testing.assert_array_equal(trailed_source_mask, [False, False, False, True, True]) - - flags = unpacker.unpack(self.edgeDiaSources["flags"], "flags") - trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.edgeDiaSources, - self.exposure_time, flags) - np.testing.assert_array_equal(trailed_source_mask, [False, True, False, False, True]) - - # Mixing the flags from edgeDiaSources and diaSources means the mask - # will be set using both criteria. - trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, - self.exposure_time, flags) - np.testing.assert_array_equal(trailed_source_mask, [False, True, False, True, True]) - - -class MemoryTester(lsst.utils.tests.MemoryTestCase): - pass - - -def setup_module(module): - lsst.utils.tests.init() - - -if __name__ == "__main__": - lsst.utils.tests.init() - unittest.main()