diff --git a/data/association-flag-map.yaml b/data/association-flag-map.yaml
index 3e8176d0..af7edc42 100644
--- a/data/association-flag-map.yaml
+++ b/data/association-flag-map.yaml
@@ -102,4 +102,4 @@ columns:
doc: Fake source template injection in source footprint
- name: base_PixelFlags_flag_injected_templateCenter
bit: 33
- doc: Fake source template injection center in source footprint
+ doc: Fake source template injection center in source footprint
\ No newline at end of file
diff --git a/python/lsst/ap/association/__init__.py b/python/lsst/ap/association/__init__.py
index 16cd34da..fd84f9a7 100644
--- a/python/lsst/ap/association/__init__.py
+++ b/python/lsst/ap/association/__init__.py
@@ -25,6 +25,5 @@
from .loadDiaCatalogs import *
from .packageAlerts import *
from .diaPipe import *
-from .trailedSourceFilter import *
from .transformDiaSourceCatalog import *
from .version import *
\ No newline at end of file
diff --git a/python/lsst/ap/association/association.py b/python/lsst/ap/association/association.py
index d2bd043a..eb4e0e41 100644
--- a/python/lsst/ap/association/association.py
+++ b/python/lsst/ap/association/association.py
@@ -32,7 +32,6 @@
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
-from .trailedSourceFilter import TrailedSourceFilterTask
# Enforce an error for unsafe column/array value setting in pandas.
pd.options.mode.chained_assignment = 'raise'
@@ -49,19 +48,6 @@ class AssociationConfig(pexConfig.Config):
default=1.0,
)
- trailedSourceFilter = pexConfig.ConfigurableField(
- target=TrailedSourceFilterTask,
- doc="Subtask to remove long trailed sources based on catalog source "
- "morphological measurements.",
- )
-
- doTrailedSourceFilter = pexConfig.Field(
- doc="Run traildeSourceFilter to remove long trailed sources from "
- "output catalog.",
- dtype=bool,
- default=True,
- )
-
class AssociationTask(pipeBase.Task):
"""Associate DIAOSources into existing DIAObjects.
@@ -75,16 +61,10 @@ class AssociationTask(pipeBase.Task):
ConfigClass = AssociationConfig
_DefaultName = "association"
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- if self.config.doTrailedSourceFilter:
- self.makeSubtask("trailedSourceFilter")
-
@timeMethod
def run(self,
diaSources,
- diaObjects,
- exposure_time=None):
+ diaObjects):
"""Associate the new DiaSources with existing DiaObjects.
Parameters
@@ -93,8 +73,6 @@ def run(self,
New DIASources to be associated with existing DIAObjects.
diaObjects : `pandas.DataFrame`
Existing diaObjects from the Apdb.
- exposure_time : `float`, optional
- Exposure time from difference image.
Returns
-------
@@ -112,30 +90,15 @@ def run(self,
matched to new DiaSources. (`int`)
- ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
not matched a new DiaSource. (`int`)
- - ``longTrailedSources`` : DiaSources which have trail lengths
- greater than max_trail_length/second*exposure_time.
- (`pandas.DataFrame``)
"""
diaSources = self.check_dia_source_radec(diaSources)
- if self.config.doTrailedSourceFilter:
- diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time)
- diaSources = diaTrailedResult.diaSources
- longTrailedSources = diaTrailedResult.longTrailedDiaSources
-
- self.log.info("%i DiaSources exceed max_trail_length, dropping from source "
- "catalog." % len(diaTrailedResult.longTrailedDiaSources))
- self.metadata.add("num_filtered", len(diaTrailedResult.longTrailedDiaSources))
- else:
- longTrailedSources = pd.DataFrame(columns=diaSources.columns)
-
if len(diaObjects) == 0:
return pipeBase.Struct(
matchedDiaSources=pd.DataFrame(columns=diaSources.columns),
unAssocDiaSources=diaSources,
nUpdatedDiaObjects=0,
- nUnassociatedDiaObjects=0,
- longTrailedSources=longTrailedSources)
+ nUnassociatedDiaObjects=0)
matchResult = self.associate_sources(diaObjects, diaSources)
@@ -145,8 +108,7 @@ def run(self,
matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True),
unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True),
nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects,
- nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects,
- longTrailedSources=longTrailedSources)
+ nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects)
def check_dia_source_radec(self, dia_sources):
"""Check that all DiaSources have non-NaN values for RA/DEC.
diff --git a/python/lsst/ap/association/diaPipe.py b/python/lsst/ap/association/diaPipe.py
index 9ee2d556..3aacd4b3 100644
--- a/python/lsst/ap/association/diaPipe.py
+++ b/python/lsst/ap/association/diaPipe.py
@@ -35,23 +35,22 @@
import warnings
-import numpy as np
-import pandas as pd
-
-from lsst.daf.base import DateTime
import lsst.dax.apdb as daxApdb
-from lsst.meas.base import DetectorVisitIdGeneratorConfig, DiaObjectCalculationTask
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as connTypes
-from lsst.utils.timer import timeMethod
-
+import numpy as np
+import pandas as pd
from lsst.ap.association import (
AssociationTask,
DiaForcedSourceTask,
LoadDiaCatalogsTask,
PackageAlertsTask)
from lsst.ap.association.ssoAssociation import SolarSystemAssociationTask
+from lsst.daf.base import DateTime
+from lsst.meas.base import DetectorVisitIdGeneratorConfig, \
+ DiaObjectCalculationTask
+from lsst.utils.timer import timeMethod
class DiaPipelineConnections(
@@ -119,12 +118,6 @@ class DiaPipelineConnections(
storageClass="DataFrame",
dimensions=("instrument", "visit", "detector"),
)
- longTrailedSources = pipeBase.connectionTypes.Output(
- doc="Optional output temporarily storing long trailed diaSources.",
- dimensions=("instrument", "visit", "detector"),
- storageClass="DataFrame",
- name="{fakesType}{coaddName}Diff_longTrailedSrc",
- )
def __init__(self, *, config=None):
super().__init__(config=config)
@@ -137,8 +130,6 @@ def __init__(self, *, config=None):
self.outputs.remove("diaForcedSources")
if not config.doSolarSystemAssociation:
self.inputs.remove("solarSystemObjectTable")
- if not config.associator.doTrailedSourceFilter:
- self.outputs.remove("longTrailedSources")
def adjustQuantum(self, inputs, outputs, label, dataId):
"""Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
@@ -438,10 +429,8 @@ def run(self,
buffer=self.config.imagePixelMargin)
else:
diaObjects = loaderResult.diaObjects
-
# Associate new DiaSources with existing DiaObjects.
- assocResults = self.associator.run(diaSourceTable, diaObjects,
- exposure_time=diffIm.visitInfo.exposureTime)
+ assocResults = self.associator.run(diaSourceTable, diaObjects)
if self.config.doSolarSystemAssociation:
ssoAssocResult = self.solarSystemAssociator.run(
@@ -627,7 +616,6 @@ def run(self,
associatedDiaSources=associatedDiaSources,
diaForcedSources=diaForcedSources,
diaObjects=diaObjects,
- longTrailedSources=assocResults.longTrailedSources
)
def createNewDiaObjects(self, unAssocDiaSources):
diff --git a/python/lsst/ap/association/filterDiaSourceCatalog.py b/python/lsst/ap/association/filterDiaSourceCatalog.py
index 9330e1e9..b6b985f7 100644
--- a/python/lsst/ap/association/filterDiaSourceCatalog.py
+++ b/python/lsst/ap/association/filterDiaSourceCatalog.py
@@ -61,10 +61,28 @@ class FilterDiaSourceCatalogConnections(
dimensions={"instrument", "visit", "detector"},
)
+ diffImVisitInfo = connTypes.Input(
+ doc="VisitInfo of diffIm.",
+ name="{fakesType}{coaddName}Diff_differenceExp.visitInfo",
+ storageClass="VisitInfo",
+ dimensions=("instrument", "visit", "detector"),
+ )
+
+ longTrailedSources = connTypes.Output(
+ doc="Optional output temporarily storing long trailed diaSources.",
+ dimensions=("instrument", "visit", "detector"),
+ storageClass="ArrowAstropy",
+ name="{fakesType}{coaddName}Diff_longTrailedSrc",
+ )
+
def __init__(self, *, config=None):
super().__init__(config=config)
- if not self.config.doWriteRejectedSources:
+ if not self.config.doWriteRejectedSkySources:
self.outputs.remove("rejectedDiaSources")
+ if not self.config.doTrailedSourceFilter:
+ self.outputs.remove("longTrailedSources")
+ if not self.config.doWriteTrailedSources:
+ self.outputs.remove("longTrailedSources")
class FilterDiaSourceCatalogConfig(
@@ -79,13 +97,37 @@ class FilterDiaSourceCatalogConfig(
"removed before storing the output DiaSource catalog.",
)
- doWriteRejectedSources = pexConfig.Field(
+ doWriteRejectedSkySources = pexConfig.Field(
dtype=bool,
default=True,
doc="Store the output DiaSource catalog containing all the rejected "
"sky sources."
)
+ doTrailedSourceFilter = pexConfig.Field(
+ doc="Run trailedSourceFilter to remove long trailed sources from the"
+ "diaSource output catalog.",
+ dtype=bool,
+ default=True,
+ )
+
+ doWriteTrailedSources = pexConfig.Field(
+ doc="Write trailed diaSources sources to a table.",
+ dtype=bool,
+ default=True,
+ deprecated="Trailed sources will not be written out during production."
+ )
+
+ max_trail_length = pexConfig.Field(
+ dtype=float,
+ doc="Length of long trailed sources to remove from the input catalog, "
+ "in arcseconds per second. Default comes from DMTN-199, which "
+ "requires removal of sources with trails longer than 10 "
+ "degrees/day, which is 36000/3600/24 arcsec/second, or roughly"
+ "0.416 arcseconds per second.",
+ default=36000/3600.0/24.0,
+ )
+
class FilterDiaSourceCatalogTask(pipeBase.PipelineTask):
"""Filter out sky sources from a DiaSource catalog."""
@@ -94,13 +136,15 @@ class FilterDiaSourceCatalogTask(pipeBase.PipelineTask):
_DefaultName = "filterDiaSourceCatalog"
@timeMethod
- def run(self, diaSourceCat):
+ def run(self, diaSourceCat, diffImVisitInfo):
"""Filter sky sources from the supplied DiaSource catalog.
Parameters
----------
diaSourceCat : `lsst.afw.table.SourceCatalog`
Catalog of sources measured on the difference image.
+ diffImVisitInfo: `lsst.afw.image.VisitInfo`
+ VisitInfo for the difference image corresponding to diaSourceCat.
Returns
-------
@@ -109,9 +153,13 @@ def run(self, diaSourceCat):
``filteredDiaSourceCat`` : `lsst.afw.table.SourceCatalog`
The catalog of filtered sources.
``rejectedDiaSources`` : `lsst.afw.table.SourceCatalog`
- The catalog of rejected sources.
+ The catalog of rejected sky sources.
+ ``longTrailedDiaSources`` : `astropy.table.Table`
+ DiaSources which have trail lengths greater than
+ max_trail_length*exposure_time.
"""
rejectedSkySources = None
+ exposure_time = diffImVisitInfo.exposureTime
if self.config.doRemoveSkySources:
sky_source_column = diaSourceCat["sky_source"]
num_sky_sources = np.sum(sky_source_column)
@@ -120,6 +168,64 @@ def run(self, diaSourceCat):
self.log.info(f"Filtered {num_sky_sources} sky sources.")
if not rejectedSkySources:
rejectedSkySources = SourceCatalog(diaSourceCat.getSchema())
- filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat,
- rejectedDiaSources=rejectedSkySources)
+
+ if self.config.doTrailedSourceFilter:
+ trail_mask = self._check_dia_source_trail(diaSourceCat, exposure_time)
+ longTrailedDiaSources = diaSourceCat[trail_mask].copy(deep=True)
+ diaSourceCat = diaSourceCat[~trail_mask]
+
+ self.log.info("%i DiaSources exceed max_trail_length %f arcseconds per second, "
+ "dropping from source catalog."
+ % (self.config.max_trail_length, len(diaSourceCat)))
+ self.metadata.add("num_filtered", len(longTrailedDiaSources))
+
+ if self.config.doWriteTrailedSources:
+ filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat,
+ rejectedDiaSources=rejectedSkySources,
+ longTrailedSources=longTrailedDiaSources.asAstropy())
+ else:
+ filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat,
+ rejectedDiaSources=rejectedSkySources)
+
+ else:
+ filterResults = pipeBase.Struct(filteredDiaSourceCat=diaSourceCat,
+ rejectedDiaSources=rejectedSkySources)
+
return filterResults
+
+ def _check_dia_source_trail(self, dia_sources, exposure_time):
+ """Find DiaSources that have long trails or trails with indeterminant
+ end points.
+
+ Return a mask of sources with lengths greater than
+ (``config.max_trail_length`` multiplied by the exposure time)
+ arcseconds.
+ Additionally, set mask if
+ ``ext_trailedSources_Naive_flag_off_image`` is set or if
+ ``ext_trailedSources_Naive_flag_suspect_long_trail`` and
+ ``ext_trailedSources_Naive_flag_edge`` are both set.
+
+ Parameters
+ ----------
+ dia_sources : `pandas.DataFrame`
+ Input diaSources to check for trail lengths.
+ exposure_time : `float`
+ Exposure time from difference image.
+
+ Returns
+ -------
+ trail_mask : `pandas.DataFrame`
+ Boolean mask for diaSources which are greater than the
+ Boolean mask for diaSources which are greater than the
+ cutoff length or have trails which extend beyond the edge of the
+ detector (off_image set). Also checks if both
+ suspect_long_trail and edge are set and masks those sources out.
+ """
+ print(dia_sources.getSchema())
+ trail_mask = (dia_sources["ext_trailedSources_Naive_length"]
+ >= (self.config.max_trail_length*exposure_time))
+ trail_mask |= dia_sources['ext_trailedSources_Naive_flag_off_image']
+ trail_mask |= (dia_sources['ext_trailedSources_Naive_flag_suspect_long_trail']
+ & dia_sources['ext_trailedSources_Naive_flag_edge'])
+
+ return trail_mask
diff --git a/python/lsst/ap/association/trailedSourceFilter.py b/python/lsst/ap/association/trailedSourceFilter.py
deleted file mode 100644
index 8408ffb2..00000000
--- a/python/lsst/ap/association/trailedSourceFilter.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# This file is part of ap_association.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (https://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see .
-
-__all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig")
-
-import os
-import numpy as np
-
-import lsst.pex.config as pexConfig
-import lsst.pipe.base as pipeBase
-from lsst.utils.timer import timeMethod
-from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags
-import lsst.utils as utils
-
-
-class TrailedSourceFilterConfig(pexConfig.Config):
- """Config class for TrailedSourceFilterTask.
- """
-
- max_trail_length = pexConfig.Field(
- dtype=float,
- doc="Length of long trailed sources to remove from the input catalog, "
- "in arcseconds per second. Default comes from DMTN-199, which "
- "requires removal of sources with trails longer than 10 "
- "degrees/day, which is 36000/3600/24 arcsec/second, or roughly"
- "0.416 arcseconds per second.",
- default=36000/3600.0/24.0,
- )
-
-
-class TrailedSourceFilterTask(pipeBase.Task):
- """Find trailed sources in DIASources and filter them as per DMTN-199
- guidelines.
-
- This task checks the length of trailLength in the DIASource catalog using
- a given arcsecond/second rate from max_trail_length and the exposure time.
- The two values are used to calculate the maximum allowed trail length and
- filters out any trail longer than the maximum. The max_trail_length is
- outlined in DMTN-199 and determines the default value.
- """
-
- ConfigClass = TrailedSourceFilterConfig
- _DefaultName = "trailedSourceFilter"
-
- @timeMethod
- def run(self, dia_sources, exposure_time):
- """Remove trailed sources longer than ``config.max_trail_length`` from
- the input catalog.
-
- Parameters
- ----------
- dia_sources : `pandas.DataFrame`
- New DIASources to be checked for trailed sources.
- exposure_time : `float`
- Exposure time from difference image.
-
- Returns
- -------
- result : `lsst.pipe.base.Struct`
- Results struct with components.
-
- - ``diaSources`` : DIASource table that is free from unwanted
- trailed sources. (`pandas.DataFrame`)
-
- - ``longTrailedDiaSources`` : DIASources that have trails which
- exceed max_trail_length/second*exposure_time (seconds).
- (`pandas.DataFrame`)
- """
-
- if "flags" in dia_sources.columns:
- flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml")
- unpacker = UnpackApdbFlags(flag_map, "DiaSource")
- flags = unpacker.unpack(dia_sources["flags"], "flags")
- trail_edge_flags = flags["ext_trailedSources_Naive_flag_edge"]
- else:
- trail_edge_flags = dia_sources["trail_flag_edge"]
-
- trail_mask = self._check_dia_source_trail(dia_sources, exposure_time, trail_edge_flags)
-
- return pipeBase.Struct(
- diaSources=dia_sources[~trail_mask].reset_index(drop=True),
- longTrailedDiaSources=dia_sources[trail_mask].reset_index(drop=True))
-
- def _check_dia_source_trail(self, dia_sources, exposure_time, trail_edge_flags):
- """Find DiaSources that have long trails.
-
- Return a mask of sources with lengths greater than
- ``config.max_trail_length`` multiplied by the exposure time in seconds
- or have ext_trailedSources_Naive_flag_edge set.
-
- Parameters
- ----------
- dia_sources : `pandas.DataFrame`
- Input DIASources to check for trail lengths.
- exposure_time : `float`
- Exposure time from difference image.
- trail_edge_flags : 'numpy.ndArray'
- Boolean array of trail_flag_edge flags from the DIASources.
-
- Returns
- -------
- trail_mask : `pandas.DataFrame`
- Boolean mask for DIASources which are greater than the
- cutoff length or have the edge flag set.
- """
- trail_mask = (dia_sources.loc[:, "trailLength"].values[:]
- >= (self.config.max_trail_length*exposure_time))
-
- trail_mask[np.where(trail_edge_flags)] = True
-
- return trail_mask
diff --git a/tests/test_association_task.py b/tests/test_association_task.py
index 22b3d8fa..62873d73 100644
--- a/tests/test_association_task.py
+++ b/tests/test_association_task.py
@@ -54,16 +54,14 @@ def setUp(self):
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx,
"flags": 0}
for idx in range(self.nSources)])
- self.exposure_time = 30.0
def test_run(self):
"""Test the full task by associating a set of diaSources to
existing diaObjects.
"""
config = AssociationTask.ConfigClass()
- config.doTrailedSourceFilter = False
assocTask = AssociationTask(config=config)
- results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)
+ results = assocTask.run(self.diaSources, self.diaObjects)
self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1)
self.assertEqual(results.nUnassociatedDiaObjects, 1)
@@ -73,31 +71,13 @@ def test_run(self):
np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2, 3, 4])
np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])
- def test_run_trailed_sources(self):
- """Test the full task by associating a set of diaSources to
- existing diaObjects when trailed sources are filtered.
-
- This should filter out two of the five sources based on trail length,
- leaving one unassociated diaSource and two associated diaSources.
- """
- assocTask = AssociationTask()
- results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)
-
- self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3)
- self.assertEqual(results.nUnassociatedDiaObjects, 3)
- self.assertEqual(len(results.matchedDiaSources), len(self.diaObjects) - 3)
- self.assertEqual(len(results.unAssocDiaSources), 1)
- np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2])
- np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])
-
def test_run_no_existing_objects(self):
"""Test the run method with a completely empty database.
"""
assocTask = AssociationTask()
results = assocTask.run(
self.diaSources,
- pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]),
- exposure_time=self.exposure_time)
+ pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]))
self.assertEqual(results.nUpdatedDiaObjects, 0)
self.assertEqual(results.nUnassociatedDiaObjects, 0)
self.assertEqual(len(results.matchedDiaSources), 0)
diff --git a/tests/test_diaPipe.py b/tests/test_diaPipe.py
index 1f3ec49e..edaed345 100644
--- a/tests/test_diaPipe.py
+++ b/tests/test_diaPipe.py
@@ -183,11 +183,10 @@ def solarSystemAssociator_run(unAssocDiaSources, solarSystemObjectTable, diffIm)
ssoAssocDiaSources=_makeMockDataFrame(),
unAssocDiaSources=_makeMockDataFrame())
- def associator_run(table, diaObjects, exposure_time=None):
+ def associator_run(table, diaObjects):
return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3,
matchedDiaSources=_makeMockDataFrame(),
- unAssocDiaSources=_makeMockDataFrame(),
- longTrailedSources=None)
+ unAssocDiaSources=_makeMockDataFrame())
# apdb isn't a subtask, but still needs to be mocked out for correct
# execution in the test environment.
diff --git a/tests/test_filterDiaSourceCatalog.py b/tests/test_filterDiaSourceCatalog.py
index 73cf1b5f..b813a3bd 100644
--- a/tests/test_filterDiaSourceCatalog.py
+++ b/tests/test_filterDiaSourceCatalog.py
@@ -26,12 +26,14 @@
import lsst.geom as geom
import lsst.meas.base.tests as measTests
import lsst.utils.tests
+import lsst.afw.image as afwImage
+import lsst.daf.base as dafBase
class TestFilterDiaSourceCatalogTask(unittest.TestCase):
def setUp(self):
- self.nSources = 10
+ self.nSources = 15
self.nSkySources = 5
self.yLoc = 100
self.expId = 4321
@@ -42,24 +44,70 @@ def setUp(self):
dataset.addSource(10000.0, geom.Point2D(srcIdx, self.yLoc))
schema = dataset.makeMinimalSchema()
schema.addField("sky_source", type="Flag", doc="Sky objects.")
+ schema.addField('ext_trailedSources_Naive_flag_off_image', type="Flag",
+ doc="Trail extends off image")
+ schema.addField('ext_trailedSources_Naive_flag_suspect_long_trail',
+ type="Flag", doc="Trail length is greater than three times the psf radius")
+ schema.addField('ext_trailedSources_Naive_flag_edge', type="Flag",
+ doc="Trail contains edge pixels")
+ schema.addField('ext_trailedSources_Naive_flag_nan', type="Flag",
+ doc="One or more trail coordinates are missing")
+ schema.addField('ext_trailedSources_Naive_length', type="F",
+ doc="Length of the source trail")
_, self.diaSourceCat = dataset.realize(10.0, schema, randomSeed=1234)
self.diaSourceCat[0:self.nSkySources]["sky_source"] = True
+ # The last 10 sources will all contained trail length measurements,
+ # increasing in size by 1.5 arcseconds. Only the last three will have
+ # lengths which are too long and will be filtered out.
+ self.nFilteredTrailedSources = 0
+ for srcIdx in range(5, 15):
+ self.diaSourceCat[srcIdx]["ext_trailedSources_Naive_length"] = 1.5*(srcIdx-4)
+ if 1.5*(srcIdx-4) > 36000/3600.0/24.0 * 30.0:
+ self.nFilteredTrailedSources += 1
+ # Setting a combination of flags for filtering in tests
+ self.diaSourceCat[5]["ext_trailedSources_Naive_flag_off_image"] = True
+ self.diaSourceCat[6]["ext_trailedSources_Naive_flag_suspect_long_trail"] = True
+ self.diaSourceCat[6]["ext_trailedSources_Naive_flag_edge"] = True
+ # As only two of these flags are set, the total number of filtered
+ # sources will be self.nFilteredTrailedSources + 2
+ self.nFilteredTrailedSources += 2
self.config = FilterDiaSourceCatalogConfig()
+ mjd = 57071.0
+ self.utc_jd = mjd + 2_400_000.5 - 35.0 / (24.0 * 60.0 * 60.0)
+
+ self.visitInfo = afwImage.VisitInfo(
+ # This incomplete visitInfo is sufficient for testing because the
+ # Python constructor sets all other required values to some
+ # default.
+ exposureTime=30.0,
+ darkTime=3.0,
+ date=dafBase.DateTime(mjd, system=dafBase.DateTime.MJD),
+ boresightRaDec=geom.SpherePoint(0.0, 0.0, geom.degrees),
+ )
def test_run_without_filter(self):
+ """Test that when all filters are turned off all sources in the catalog
+ are returned.
+ """
self.config.doRemoveSkySources = False
- self.config.doWriteRejectedSources = True
+ self.config.doWriteRejectedSkySources = False
+ self.config.doTrailedSourceFilter = False
filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
- result = filterDiaSourceCatalogTask.run(self.diaSourceCat)
+ result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat))
self.assertEqual(len(result.rejectedDiaSources), 0)
self.assertEqual(len(self.diaSourceCat), self.nSources)
- def test_run_with_filter(self):
+ def test_run_with_filter_sky_only(self):
+ """Test that when only the sky filter is turned on the first five
+ sources which are flagged as sky objects are filtered out of the
+ catalog and the rest are returned.
+ """
self.config.doRemoveSkySources = True
- self.config.doWriteRejectedSources = True
+ self.config.doWriteRejectedSkySources = True
+ self.config.doTrailedSourceFilter = False
filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
- result = filterDiaSourceCatalogTask.run(self.diaSourceCat)
+ result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
nExpectedFilteredSources = self.nSources - self.nSkySources
self.assertEqual(len(result.filteredDiaSourceCat),
len(self.diaSourceCat[~self.diaSourceCat['sky_source']]))
@@ -67,6 +115,39 @@ def test_run_with_filter(self):
self.assertEqual(len(result.rejectedDiaSources), self.nSkySources)
self.assertEqual(len(self.diaSourceCat), self.nSources)
+ def test_run_with_filter_trailed_sources_only(self):
+ """Test that when only the trail filter is turned on the correct number
+ of sources are filtered out. The filtered sources should be the last
+ three sources which have long trails, one source where both the suspect
+ trail and edge trail flag are set, and one source where off_image is
+ set. All sky objects should remain in the catalog.
+ """
+ self.config.doRemoveSkySources = False
+ self.config.doWriteRejectedSkySources = False
+ self.config.doTrailedSourceFilter = True
+ filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
+ result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
+ nExpectedFilteredSources = self.nSources - self.nFilteredTrailedSources
+ self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
+ self.assertEqual(len(self.diaSourceCat), self.nSources)
+
+ def test_run_with_all_filters(self):
+ """Test that all sources are filtered out correctly. Only six sources
+ should remain in the catalog after filtering.
+ """
+ self.config.doRemoveSkySources = True
+ self.config.doWriteRejectedSkySources = True
+ self.config.doTrailedSourceFilter = True
+ filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
+ result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
+ nExpectedFilteredSources = self.nSources - self.nSkySources - self.nFilteredTrailedSources
+ # 5 filtered out sky sources
+ # 4 filtered out trailed sources, 2 with long trails 2 with flags
+ # 6 sources left
+ self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
+ self.assertEqual(len(result.rejectedDiaSources), self.nSkySources)
+ self.assertEqual(len(self.diaSourceCat), self.nSources)
+
class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass
diff --git a/tests/test_trailedSourceFilter.py b/tests/test_trailedSourceFilter.py
deleted file mode 100644
index e777aaea..00000000
--- a/tests/test_trailedSourceFilter.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# This file is part of ap_association.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (https://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see .
-
-import unittest
-import os
-import numpy as np
-import pandas as pd
-
-import lsst.utils.tests
-import lsst.utils as utils
-from lsst.ap.association import TrailedSourceFilterTask
-from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags
-
-
-class TestTrailedSourceFilterTask(unittest.TestCase):
-
- def setUp(self):
- """Create sets of diaSources.
-
- The trail lengths of the dia sources are 0, 5.5, 11, 16.5, 21.5
- arcseconds.
- """
- # Create an instance of random generator with fixed seed.
- rng = np.random.default_rng(1234)
-
- scatter = 0.1 / 3600
- self.nSources = 5
- self.diaSources = pd.DataFrame(data=[
- {"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
- "dec": 0.04*idx + scatter*rng.uniform(-1, 1),
- "diaSourceId": idx, "diaObjectId": 0, "trailLength": 5.5*idx,
- "flags": 0}
- for idx in range(self.nSources)])
- self.exposure_time = 30.0
-
- # For use only with testing the edge flag
- self.edgeDiaSources = pd.DataFrame(data=[
- {"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
- "dec": 0.04*idx + scatter*rng.uniform(-1, 1),
- "diaSourceId": idx, "diaObjectId": 0, "trailLength": 0,
- "flags": 0}
- for idx in range(self.nSources)])
-
- flagMap = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml")
- unpacker = UnpackApdbFlags(flagMap, "DiaSource")
- bitMask = unpacker.makeFlagBitMask(["ext_trailedSources_Naive_flag_edge"])
- # Flag two sources as "trailed on the edge".
- self.edgeDiaSources.loc[[1, 4], "flags"] |= bitMask
-
- def test_run(self):
- """Run trailedSourceFilterTask with the default max distance.
-
- With the default settings and an exposure of 30 seconds, the max trail
- length is 12.5 arcseconds. Two out of five of the diaSources will be
- filtered out of the final results and put into results.trailedSources.
- """
- trailedSourceFilterTask = TrailedSourceFilterTask()
-
- results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)
-
- self.assertEqual(len(results.diaSources), 3)
- np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 1, 2])
- np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [3, 4])
-
- def test_run_short_max_trail(self):
- """Run trailedSourceFilterTask with aggressive trail length cutoff
-
- With a max_trail_length config of 0.01 arcseconds/second and an
- exposure of 30 seconds,the max trail length is 0.3 arcseconds. Only the
- source with a trail of 0 stays in the catalog and the rest are filtered
- out and put into results.trailedSources.
- """
- config = TrailedSourceFilterTask.ConfigClass()
- config.max_trail_length = 0.01
- trailedSourceFilterTask = TrailedSourceFilterTask(config=config)
- results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)
-
- self.assertEqual(len(results.diaSources), 1)
- np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0])
- np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 2, 3, 4])
-
- def test_run_no_trails(self):
- """Run trailedSourceFilterTask with a long trail length so that
- every source in the catalog is in the final diaSource catalog.
-
- With a max_trail_length config of 10 arcseconds/second and an
- exposure of 30 seconds,the max trail length is 300 arcseconds. All
- sources in the initial catalog should be in the final diaSource
- catalog.
- """
- config = TrailedSourceFilterTask.ConfigClass()
- config.max_trail_length = 10.00
- trailedSourceFilterTask = TrailedSourceFilterTask(config=config)
- results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)
-
- self.assertEqual(len(results.diaSources), 5)
- self.assertEqual(len(results.longTrailedDiaSources), 0)
- np.testing.assert_array_equal(results.diaSources["diaSourceId"].values, [0, 1, 2, 3, 4])
- np.testing.assert_array_equal(results.longTrailedDiaSources["diaSourceId"].values, [])
-
- def test_run_edge(self):
- """Run trailedSourceFilterTask on a source on the edge.
- """
- trailedSourceFilterTask = TrailedSourceFilterTask()
-
- results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time)
-
- self.assertEqual(len(results.diaSources), 3)
- np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 2, 3])
- np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 4])
-
- def test_check_dia_source_trail(self):
- """Test that the DiaSource trail checker is correctly identifying
- long trails
-
- Test that the trail source mask filter returns the expected mask array.
- """
- trailedSourceFilterTask = TrailedSourceFilterTask()
- flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml")
- unpacker = UnpackApdbFlags(flag_map, "DiaSource")
- flags = unpacker.unpack(self.diaSources["flags"], "flags")
- trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources,
- self.exposure_time, flags)
-
- np.testing.assert_array_equal(trailed_source_mask, [False, False, False, True, True])
-
- flags = unpacker.unpack(self.edgeDiaSources["flags"], "flags")
- trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.edgeDiaSources,
- self.exposure_time, flags)
- np.testing.assert_array_equal(trailed_source_mask, [False, True, False, False, True])
-
- # Mixing the flags from edgeDiaSources and diaSources means the mask
- # will be set using both criteria.
- trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources,
- self.exposure_time, flags)
- np.testing.assert_array_equal(trailed_source_mask, [False, True, False, True, True])
-
-
-class MemoryTester(lsst.utils.tests.MemoryTestCase):
- pass
-
-
-def setup_module(module):
- lsst.utils.tests.init()
-
-
-if __name__ == "__main__":
- lsst.utils.tests.init()
- unittest.main()