Skip to content

Commit

Permalink
Fix tests after loadCatalogsTask refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
isullivan committed Aug 1, 2024
1 parent 3d46fb8 commit 182a670
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 14 deletions.
18 changes: 13 additions & 5 deletions tests/test_diaPipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@
import lsst.afw.image as afwImage
import lsst.afw.table as afwTable
import lsst.dax.apdb as daxApdb
from lsst.meas.base import IdGenerator
import lsst.pex.config as pexConfig
import lsst.utils.tests
from lsst.pipe.base.testUtils import assertValidOutput

from lsst.ap.association import DiaPipelineTask
from utils_tests import makeExposure, makeDiaObjects
from utils_tests import makeExposure, makeDiaObjects, makeDiaSources, makeDiaForcedSources


def _makeMockDataFrame():
Expand Down Expand Up @@ -77,6 +78,12 @@ def setUp(self):
srcSchema.addField("base_PixelFlags_flag", type="Flag")
srcSchema.addField("base_PixelFlags_flag_offimage", type="Flag")
self.srcSchema = afwTable.SourceCatalog(srcSchema)
self.exposure = makeExposure(False, False)
self.diaObjects = makeDiaObjects(20, self.exposure, rng)
self.diaSources = makeDiaSources(
100, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
self.diaForcedSources = makeDiaForcedSources(
200, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)

apdb_config = daxApdb.ApdbSql.init_database(db_url="sqlite://")
self.config_file = tempfile.NamedTemporaryFile()
Expand Down Expand Up @@ -154,13 +161,11 @@ def _testRun(self, doPackageAlerts=False, doSolarSystemAssociation=False):
template = Mock(spec=afwImage.ExposureF)
diaSrc = _makeMockDataFrame()
ssObjects = _makeMockDataFrame()
ccdExposureIdBits = 32

# Each of these subtasks should be called once during diaPipe
# execution. We use mocks here to check they are being executed
# appropriately.
subtasksToMock = [
"diaCatalogLoader",
"diaCalculation",
"diaForcedSource",
]
Expand Down Expand Up @@ -202,8 +207,11 @@ def associator_run(table, diaObjects):
diffIm,
exposure,
template,
ccdExposureIdBits,
"g")
self.diaObjects,
self.diaSources,
self.diaForcedSources,
"g",
IdGenerator())
for subtaskName in subtasksToMock:
getattr(task, subtaskName).run.assert_called_once()
assertValidOutput(task, result)
Expand Down
29 changes: 20 additions & 9 deletions tests/test_loadDiaCatalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
import unittest
import yaml

from lsst.ap.association import LoadDiaCatalogsTask, LoadDiaCatalogsConfig
from lsst.ap.association import LoadDiaCatalogsTask
from lsst.ap.association.utils import readSchemaFromApdb
from lsst.dax.apdb import Apdb, ApdbSql, ApdbTables
from lsst.utils import getPackageDir
import lsst.utils.tests
from utils_tests import makeExposure, makeDiaObjects, makeDiaSources, makeDiaForcedSources
from utils_tests import makeExposure, makeDiaObjects, makeDiaSources, makeDiaForcedSources, makeRegionTime


def _data_file_name(basename, module_name):
Expand Down Expand Up @@ -64,19 +64,23 @@ def setUp(self):
self.addCleanup(os.close, self.db_file_fd)

self.apdbConfig = ApdbSql.init_database(db_url="sqlite:///" + self.db_file)
self.config_file = tempfile.NamedTemporaryFile()
self.addCleanup(self.config_file.close)
self.apdbConfig.save(self.config_file.name)
self.apdb = Apdb.from_config(self.apdbConfig)
self.schema = readSchemaFromApdb(self.apdb)

self.exposure = makeExposure(False, False)
self.regionTime = makeRegionTime(exposure=self.exposure)

self.diaObjects = makeDiaObjects(20, self.exposure, rng)
self.diaSources = makeDiaSources(
100, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
self.diaForcedSources = makeDiaForcedSources(
200, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)

self.dateTime = self.exposure.visitInfo.date
self.apdb.store(self.dateTime.toAstropy(),
self.dateTime = self.regionTime.timespan.begin.tai
self.apdb.store(self.dateTime,
self.diaObjects,
self.diaSources,
self.diaForcedSources)
Expand All @@ -88,8 +92,10 @@ def setUp(self):
def testRun(self):
"""Test the full run method for the loader.
"""
diaLoader = LoadDiaCatalogsTask()
result = diaLoader.run(self.exposure, self.apdb)
diaConfig = LoadDiaCatalogsTask.ConfigClass()
diaConfig.apdb_config_url = self.config_file.name
diaLoader = LoadDiaCatalogsTask(config=diaConfig)
result = diaLoader.run(self.regionTime)

self.assertEqual(len(result.diaObjects), len(self.diaObjects))
self.assertEqual(len(result.diaSources), len(self.diaSources))
Expand All @@ -99,7 +105,9 @@ def testRun(self):
def testLoadDiaObjects(self):
"""Test that the correct number of diaObjects are loaded.
"""
diaLoader = LoadDiaCatalogsTask()
diaConfig = LoadDiaCatalogsTask.ConfigClass()
diaConfig.apdb_config_url = self.config_file.name
diaLoader = LoadDiaCatalogsTask(config=diaConfig)
region = diaLoader._getRegion(self.exposure)
diaObjects = diaLoader.loadDiaObjects(region,
self.apdb,
Expand All @@ -109,7 +117,9 @@ def testLoadDiaObjects(self):
def testLoadDiaForcedSources(self):
"""Test that the correct number of diaForcedSources are loaded.
"""
diaLoader = LoadDiaCatalogsTask()
diaConfig = LoadDiaCatalogsTask.ConfigClass()
diaConfig.apdb_config_url = self.config_file.name
diaLoader = LoadDiaCatalogsTask(config=diaConfig)
region = diaLoader._getRegion(self.exposure)
diaForcedSources = diaLoader.loadDiaForcedSources(
self.diaObjects,
Expand All @@ -125,7 +135,8 @@ def testLoadDiaSources(self):
Also check that they can be properly loaded both by location and
``diaObjectId``.
"""
diaConfig = LoadDiaCatalogsConfig()
diaConfig = LoadDiaCatalogsTask.ConfigClass()
diaConfig.apdb_config_url = self.config_file.name
diaLoader = LoadDiaCatalogsTask(config=diaConfig)

region = diaLoader._getRegion(self.exposure)
Expand Down
35 changes: 35 additions & 0 deletions tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""Helper functions for tests of DIA catalogs, including generating mock
catalogs for simulated APDB access.
"""
import astropy.units
import datetime
import pandas as pd
import numpy as np
Expand All @@ -30,7 +31,10 @@
import lsst.afw.geom as afwGeom
import lsst.afw.image as afwImage
import lsst.daf.base as dafBase
import lsst.daf.butler as dafButler
import lsst.geom
from lsst.pipe.base.utils import RegionTimeInfo
import lsst.sphgeom


def makeDiaObjects(nObjects, exposure, rng):
Expand Down Expand Up @@ -237,3 +241,34 @@ def makeExposure(flipX=False, flipY=False):
exposure.setFilter(afwImage.FilterLabel(band='g'))

return exposure


def makeRegionTime(exposure=None):
if exposure is None:
exposure = makeExposure()
region = getRegion(exposure)
begin = exposure.visitInfo.date.toAstropy()
end = begin + exposure.visitInfo.exposureTime*astropy.units.second
timespan = dafButler.Timespan(begin=begin, end=end)
return RegionTimeInfo(region=region, timespan=timespan)


def getRegion(exposure):
"""Calculate an enveloping region for an exposure.
Parameters
----------
exposure : `lsst.afw.image.Exposure`
Exposure object with calibrated WCS.
Returns
-------
region : `lsst.sphgeom.Region`
Region enveloping an exposure.
"""
bbox = lsst.geom.Box2D(exposure.getBBox())
wcs = exposure.getWcs()

region = lsst.sphgeom.ConvexPolygon([wcs.pixelToSky(pp).getVector()
for pp in bbox.getCorners()])
return region

0 comments on commit 182a670

Please sign in to comment.