Skip to content

Commit

Permalink
Update filter parameters, docstrings, and unit tests
Browse files Browse the repository at this point in the history
Update

Flake8

Fix

Unit test fix
  • Loading branch information
bsmartradio committed Mar 21, 2024
1 parent 366cd22 commit b93c8dd
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
19 changes: 6 additions & 13 deletions python/lsst/ap/association/trailedSourceFilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
__all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig")

import os
import numpy as np

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
Expand Down Expand Up @@ -119,18 +118,12 @@ def _check_dia_source_trail(self, dia_sources, exposure_time, flags):
-------
trail_mask : `pandas.DataFrame`
Boolean mask for DIASources which are greater than the
cutoff length or have off_image or suspect_long_trail
flag set.
cutoff length or have off_image setl. Also checks if both
suspect_long_trail and edge are set and masks those sources out.
"""
trail_mask = (dia_sources.loc[:, "trailLength"].values[:]
>= (self.config.max_trail_length*exposure_time))

long_flags = flags['ext_trailedSources_Naive_flag_suspect_long_trail']
edge_flags = flags['ext_trailedSources_Naive_flag_edge']

trail_mask[np.where(flags['ext_trailedSources_Naive_flag_off_image'])] = True
for index, value in enumerate(long_flags):
if value and edge_flags[index]:
trail_mask[index] = True
trail_mask = dia_sources.loc[:, "trailLength"] >= (self.config.max_trail_length*exposure_time)
trail_mask |= flags['ext_trailedSources_Naive_flag_off_image']
trail_mask |= (flags['ext_trailedSources_Naive_flag_suspect_long_trail']
& flags['ext_trailedSources_Naive_flag_edge'])

return trail_mask
16 changes: 11 additions & 5 deletions tests/test_trailedSourceFilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def setUp(self):
"flags": 0}
for idx in range(self.nSources)])
self.exposure_time = 30.0
self.diaSources.loc[[2], 'flags'] = np.power(2, 35)

# For use only with testing the edge flag
self.edgeDiaSources = pd.DataFrame(data=[
Expand All @@ -58,6 +59,7 @@ def setUp(self):
for idx in range(self.nSources)])

self.edgeDiaSources.loc[[1, 4], 'flags'] = np.power(2, 27) + np.power(2, 36)
self.edgeDiaSources.loc[[2], 'flags'] = np.power(2, 35)

def test_run(self):
"""Run trailedSourceFilterTask with the default max distance.
Expand Down Expand Up @@ -109,19 +111,23 @@ def test_run_no_trails(self):
self.assertEqual(len(results.longTrailedDiaSources), 0)
np.testing.assert_array_equal(results.diaSources["diaSourceId"].values, [0, 1, 2, 3, 4])
np.testing.assert_array_equal(results.longTrailedDiaSources["diaSourceId"].values, [])
np.testing.assert_array_equal(results.diaSources['flags'][2], np.power(2, 35))

def test_run_edge(self):
"""Run trailedSourceFilterTask with the default max distance.
filtered out of the final results and put into results.trailedSources.
Check that the two sources with the edge and suspect long trail flags
are correctly filtered out.
"""
trailedSourceFilterTask = TrailedSourceFilterTask()

results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time)
results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time)

# Only three sources should remain after filtering.
self.assertEqual(len(results.diaSources), 3)
np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 2, 3])
np.testing.assert_array_equal(results.longTrailedDiaSources['diaSourceId'].values, [1, 4])
# Check that the nan flagged source is not filtered out.
np.testing.assert_array_equal(results.diaSources['flags'][1], np.power(2, 35))

def test_check_dia_source_trail(self):
"""Test that the DiaSource trail checker is correctly identifying
Expand All @@ -131,13 +137,13 @@ def test_check_dia_source_trail(self):
"""
trailedSourceFilterTask = TrailedSourceFilterTask()
flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml")
# Sources which just have long trails in their dia source table.
unpacker = UnpackApdbFlags(flag_map, "DiaSource")
flags = unpacker.unpack(self.diaSources["flags"], "flags")
trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources,
self.exposure_time, flags)

np.testing.assert_array_equal(trailed_source_mask, [False, False, False, True, True])

# Sources which have no long trails but edge and suspect long trails.
flags = unpacker.unpack(self.edgeDiaSources["flags"], "flags")
trailed_source_mask = trailedSourceFilterTask._check_dia_source_trail(self.edgeDiaSources,
self.exposure_time, flags)
Expand Down

0 comments on commit b93c8dd

Please sign in to comment.