From c2b6447f333faab971eb4e873556e694183febd0 Mon Sep 17 00:00:00 2001 From: Brianna Smart Date: Mon, 4 Mar 2024 13:38:41 -0800 Subject: [PATCH] Update filter parameters, docstrings, and unit tests Update Flake8 --- .../ap/association/trailedSourceFilter.py | 19 ++++++------------- tests/test_trailedSourceFilter.py | 15 ++++++++++----- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/python/lsst/ap/association/trailedSourceFilter.py b/python/lsst/ap/association/trailedSourceFilter.py index 13c48b6b..6adceefe 100644 --- a/python/lsst/ap/association/trailedSourceFilter.py +++ b/python/lsst/ap/association/trailedSourceFilter.py @@ -22,7 +22,6 @@ __all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig") import os -import numpy as np import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase @@ -119,18 +118,12 @@ def _check_dia_source_trail(self, dia_sources, exposure_time, flags): ------- trail_mask : `pandas.DataFrame` Boolean mask for DIASources which are greater than the - cutoff length or have off_image or suspect_long_trail - flag set. + cutoff length or have off_image setl. Also checks if both + suspect_long_trail and edge are set and masks those sources out. """ - trail_mask = (dia_sources.loc[:, "trailLength"].values[:] - >= (self.config.max_trail_length*exposure_time)) - - long_flags = flags['ext_trailedSources_Naive_flag_suspect_long_trail'] - edge_flags = flags['ext_trailedSources_Naive_flag_edge'] - - trail_mask[np.where(flags['ext_trailedSources_Naive_flag_off_image'])] = True - for index, value in enumerate(long_flags): - if value and edge_flags[index]: - trail_mask[index] = True + trail_mask = dia_sources.loc[:, "trailLength"] >= (self.config.max_trail_length*exposure_time) + trail_mask |= flags['ext_trailedSources_Naive_flag_off_image'] + trail_mask |= (flags['ext_trailedSources_Naive_flag_suspect_long_trail'] + & flags['ext_trailedSources_Naive_flag_edge']) return trail_mask diff --git a/tests/test_trailedSourceFilter.py b/tests/test_trailedSourceFilter.py index 6b2c02d6..4743ac63 100644 --- a/tests/test_trailedSourceFilter.py +++ b/tests/test_trailedSourceFilter.py @@ -58,6 +58,7 @@ def setUp(self): for idx in range(self.nSources)]) self.edgeDiaSources.loc[[1, 4], 'flags'] = np.power(2, 27) + np.power(2, 36) + self.edgeDiaSources.loc[[2], 'flags'] = np.power(2, 35) def test_run(self): """Run trailedSourceFilterTask with the default max distance. @@ -109,19 +110,23 @@ def test_run_no_trails(self): self.assertEqual(len(results.longTrailedDiaSources), 0) np.testing.assert_array_equal(results.diaSources["diaSourceId"].values, [0, 1, 2, 3, 4]) np.testing.assert_array_equal(results.longTrailedDiaSources["diaSourceId"].values, []) + np.testing.assert_array_equal(results.diaSources['flags'][2].values, np.power(2, 35)) def test_run_edge(self): """Run trailedSourceFilterTask with the default max distance. - filtered out of the final results and put into results.trailedSources. + + Check that the two sources with the edge and suspect long trail flags + are correctly filtered out. """ trailedSourceFilterTask = TrailedSourceFilterTask() - - results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time) results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time) + # Only three sources should remain after filtering. self.assertEqual(len(results.diaSources), 3) np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 2, 3]) np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 4]) + # Check that the nan flagged source is not filtered out. + np.testing.assert_array_equal(results.diaSources['flags'][1].values, np.power(2, 35)) def test_check_dia_source_trail(self): """Test that the DiaSource trail checker is correctly identifying @@ -131,13 +136,13 @@ def test_check_dia_source_trail(self): """ trailedSourceFilterTask = TrailedSourceFilterTask() flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml") + # Sources which just have long trails in their dia source table. unpacker = UnpackApdbFlags(flag_map, "DiaSource") flags = unpacker.unpack(self.diaSources["flags"], "flags") trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, self.exposure_time, flags) - np.testing.assert_array_equal(trailed_source_mask, [False, False, False, True, True]) - + # Sources which have no long trails but edge and suspect long trails. flags = unpacker.unpack(self.edgeDiaSources["flags"], "flags") trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.edgeDiaSources, self.exposure_time, flags)