Skip to content

Commit

Permalink
Update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bsmartradio committed Oct 10, 2023
1 parent 520a6c4 commit 6aa4352
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 24 deletions.
13 changes: 9 additions & 4 deletions python/lsst/ap/association/trailedSourceFilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@

__all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig")

import os
import numpy as np

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags
import os
import lsst.utils as utils


class TrailedSourceFilterConfig(pexConfig.Config):
Expand Down Expand Up @@ -82,7 +85,7 @@ def run(self, dia_sources, exposure_time):
(`pandas.DataFrame`)
"""

flag_map = os.path.join("${AP_ASSOCIATION_DIR}", "/data/association-flag-map.yaml")
flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml")
unpacker = UnpackApdbFlags(flag_map, "DiaSource")
flags = unpacker.unpack(dia_sources["flags"], "flags")

Expand Down Expand Up @@ -112,7 +115,9 @@ def _check_dia_source_trail(self, dia_sources, exposure_time, flags):
cutoff length.
"""
trail_mask = (dia_sources.loc[:, "trailLength"].values[:]
>= (self.config.max_trail_length*exposure_time)
or flags['ext_trailedSources_Naive_flag_edge'] is True)
>= (self.config.max_trail_length*exposure_time))

edge_loc = np.where(flags['ext_trailedSources_Naive_flag_edge'] is True)
trail_mask[edge_loc] = True

return trail_mask
8 changes: 5 additions & 3 deletions tests/test_association_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import numpy as np
import pandas as pd
import unittest

import lsst.geom as geom
import lsst.utils.tests

from lsst.ap.association import AssociationTask


Expand All @@ -45,12 +45,14 @@ def setUp(self):
self.diaSources = pd.DataFrame(data=[
{"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
"dec": 0.04*idx + scatter*rng.uniform(-1, 1),
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx,
"flags": 0}
for idx in range(self.nSources)])
self.diaSourceZeroScatter = pd.DataFrame(data=[
{"ra": 0.04*idx,
"dec": 0.04*idx,
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
"diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx,
"flags": 0}
for idx in range(self.nSources)])
self.exposure_time = 30.0

Expand Down
62 changes: 45 additions & 17 deletions tests/test_trailedSourceFilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import unittest
from unittest.mock import patch, MagicMock
from lsst.ap.association import TrailedSourceFilterTask
import os
import numpy as np
import pandas as pd

import lsst.utils.tests
import pydevd
import time
import lsst.utils as utils
from lsst.ap.association import TrailedSourceFilterTask
from lsst.ap.association.transformDiaSourceCatalog import UnpackApdbFlags


class TestTrailedSourceFilterTask(unittest.TestCase):
Expand All @@ -44,13 +45,19 @@ def setUp(self):
{"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
"dec": 0.04*idx + scatter*rng.uniform(-1, 1),
"diaSourceId": idx, "diaObjectId": 0, "trailLength": 5.5*idx,
"flags": 4}
"flags": 0}
for idx in range(self.nSources)])
self.exposure_time = 30.0

while pydevd.get_global_debugger() is None or not pydevd.get_global_debugger().ready_to_run:
time.sleep(0.3)
breakpoint() # breaks here
# For use only with testing the edge flag`
self.edgeDiaSources = pd.DataFrame(data=[
{"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
"dec": 0.04*idx + scatter*rng.uniform(-1, 1),
"diaSourceId": idx, "diaObjectId": 0, "trailLength": 0,
"flags": 0}
for idx in range(self.nSources)])

self.edgeDiaSources.loc[[1, 4], 'flags'] = np.power(2, 27)

def test_run(self):
"""Run trailedSourceFilterTask with the default max distance.
Expand All @@ -60,18 +67,13 @@ def test_run(self):
filtered out of the final results and put into results.trailedSources.
"""
trailedSourceFilterTask = TrailedSourceFilterTask()
mockUnpacker = MagicMock()

mockUnpacker.return_value = [False, True, False]

with patch.object(lsst.ap.association.transformDiaSourceCatalog.UnpackApdbFlags, 'unpack', mockUnpacker):
results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)
results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)

self.assertEqual(len(results.diaSources), 3)
np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 1, 2])
np.testing.assert_array_equal(results.trailedDiaSources['diaSourceId'].values, [3, 4])

@unittest.skip("reason for skipping")
def test_run_short_max_trail(self):
"""Run trailedSourceFilterTask with aggressive trail length cutoff
Expand All @@ -89,7 +91,6 @@ def test_run_short_max_trail(self):
np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0])
np.testing.assert_array_equal(results.trailedDiaSources['diaSourceId'].values, [1, 2, 3, 4])

@unittest.skip("reason for skipping")
def test_run_no_trails(self):
"""Run trailedSourceFilterTask with a long trail length so that
every source in the catalog is in the final diaSource catalog.
Expand All @@ -109,16 +110,43 @@ def test_run_no_trails(self):
np.testing.assert_array_equal(results.diaSources["diaSourceId"].values, [0, 1, 2, 3, 4])
np.testing.assert_array_equal(results.trailedDiaSources["diaSourceId"].values, [])

@unittest.skip("reason for skipping")
def test_run_edge(self):
"""Run trailedSourceFilterTask with the default max distance.
filtered out of the final results and put into results.trailedSources.
"""
trailedSourceFilterTask = TrailedSourceFilterTask()

results = trailedSourceFilterTask.run(self.edgeDiaSources, self.exposure_time)

self.assertEqual(len(results.diaSources), 3)
np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 2, 3])
np.testing.assert_array_equal(results.trailedDiaSources['diaSourceId'].values, [1, 4])

def test_check_dia_source_trail(self):
"""Test the source trail mask filter.
Test that the mask filter returns the expected mask array.
"""
trailedSourceFilterTask = TrailedSourceFilterTask()
mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, self.exposure_time)
flag_map = os.path.join(utils.getPackageDir("ap_association"), "data/association-flag-map.yaml")
unpacker = UnpackApdbFlags(flag_map, "DiaSource")
flags = unpacker.unpack(self.diaSources["flags"], "flags")
mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, self.exposure_time,
flags)

np.testing.assert_array_equal(mask, [False, False, False, True, True])

flags = unpacker.unpack(self.edgeDiaSources["flags"], "flags")
mask = trailedSourceFilterTask._check_dia_source_trail(self.edgeDiaSources, self.exposure_time,
flags)
np.testing.assert_array_equal(mask, [False, True, False, False, True])

# Mixing the flags from edgeDiaSources and diaSources means the mask
# will be set using both criteria
mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, self.exposure_time,
flags)
np.testing.assert_array_equal(mask, [False, True, False, True, True])


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass
Expand Down

0 comments on commit 6aa4352

Please sign in to comment.