Skip to content

Commit

Permalink
fixup again, moving home to test decamimdiff gen
Browse files Browse the repository at this point in the history
  • Loading branch information
DinoBektesevic committed Aug 6, 2024
1 parent 866cfaf commit 8c9ae11
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 58 deletions.
29 changes: 9 additions & 20 deletions src/kbmod/mocking/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,14 @@ class DECamImdiff:
def from_defaults(
cls,
with_data=False,
override_original=False,
shape=(100, 100),
start_mjd=60310,
step_mjd=0.001,
with_noise=False,
noise="simplistic",
src_cat=None,
obj_cat=None,
editable_images=False,
separate_masks=False,
writeable_masks=False,
editable_masks=False,
obj_cat=None
):
if obj_cat.config.type == "progressive":
raise ValueError(
Expand All @@ -202,15 +203,8 @@ def from_defaults(
hdr_factory = ArchivedHeader("headers_archive.tar.bz2", "decam_imdiff_headers.ecsv")

hdu_types = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU]
hdu_types.extend(
[
BinTableHDU,
]
* 12
)
data = [
NoneFactory(),
] * 16
hdu_types.extend([BinTableHDU] * 12)
data = [NoneFactory()] * 16

if with_data:
headers = hdr_factory.get(0)
Expand All @@ -236,12 +230,7 @@ def __init__(self, header_factory, data_factories=None, obj_cat=None):
self.hdr_factory = header_factory
self.data_factories = data_factories
self.hdu_layout = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU]
self.hdu_layout.extend(
[
BinTableHDU,
]
* 12
)
self.hdu_layout.extend([BinTableHDU] * 12)

def mock(self, n=1):
obj_cats = None
Expand Down
14 changes: 8 additions & 6 deletions src/kbmod/mocking/fits_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,13 @@ class SimpleImage(DataFactory):

default_config = SimpleImageConfig

def __init__(self, image=None, src_cat=None, obj_cat=None, config=None, **kwargs):
def __init__(self, image=None, src_cat=None, obj_cat=None, config=None,
dtype=np.float32, **kwargs):
self.config = self.default_config(config=config, **kwargs)
super().__init__(image, self.config, **kwargs)

if image is None:
image = np.zeros(self.config.shape, dtype=np.float32)
image = np.zeros(self.config.shape, dtype=dtype)
else:
image = image
self.config.shape = image.shape
Expand Down Expand Up @@ -794,7 +795,7 @@ def add_noise(cls, images, config):
return images

@classmethod
def gen_base_image(cls, config=None, src_cat=None):
def gen_base_image(cls, config=None, src_cat=None, dtype=np.float32):
"""Generate base image from configuration.
Parameters
Expand All @@ -812,7 +813,7 @@ def gen_base_image(cls, config=None, src_cat=None):
config = cls.default_config(config)

# empty image
base = np.zeros(config.shape, dtype=np.float32)
base = np.zeros(config.shape, dtype=dtype)
base += config.bias
base = cls.add_hot_pixels(base, config)
base = cls.add_bad_cols(base, config)
Expand All @@ -821,7 +822,8 @@ def gen_base_image(cls, config=None, src_cat=None):

return base

def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, **kwargs):
def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, dtype=np.float32,**kwargs):
conf = self.default_config(config=config, **kwargs)
# static objects are added in SimpleImage init
super().__init__(image=self.gen_base_image(conf), config=conf, src_cat=src_cat, obj_cat=obj_cat)
super().__init__(image=self.gen_base_image(conf, dtype=dtype),
config=conf, src_cat=src_cat, obj_cat=obj_cat)
90 changes: 64 additions & 26 deletions src/kbmod/mocking/headers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings

from astropy.utils.exceptions import AstropyUserWarning
import numpy as np

from astropy.wcs import WCS
from astropy.io.fits import Header

Expand All @@ -14,6 +15,54 @@
]


def make_wcs(center_coords=(351., -5.), rotation=0, pixscale=0.2, shape=None):
"""
Create a simple celestial `~astropy.wcs.WCS` object in ICRS
coordinate system.
Parameters
----------
shape : tuple[int]
Two-tuple, dimensions of the WCS footprint
center_coords : tuple[int]
Two-tuple of on-sky coordinates of the center of the WCS in
decimal degrees, in ICRS.
rotation : float, optional
Rotation in degrees, from ICRS equator. In decimal degrees.
pixscale : float
Pixel scale in arcsec/pixel.
Returns
-------
wcs : `astropy.wcs.WCS`
The world coordinate system.
Examples
--------
>>> from kbmod.mocking import make_wcs
>>> shape = (100, 100)
>>> wcs = make_wcs(shape)
>>> wcs = make_wcs(shape, (115, 5), 45, 0.1)
"""
wcs = WCS(naxis=2)
rho = rotation*0.0174533 # deg to rad
scale = 0.1 / 3600.0 # arcsec/pixel to deg/pix

if shape is not None:
wcs.pixel_shape = shape
wcs.wcs.crpix = [shape[1] / 2, shape[0] / 2]
else:
wcs.wcs.crpix = [0, 0]
wcs.wcs.crval = center_coords
wcs.wcs.cunit = ['deg', 'deg']
wcs.wcs.cd = [[-scale * np.cos(rho), scale * np.sin(rho)],
[scale * np.sin(rho), scale * np.cos(rho)]]
wcs.wcs.radesys = 'ICRS'
wcs.wcs.ctype = ['RA---TAN', 'DEC--TAN']

return wcs


class HeaderFactory:
primary_template = {
"EXTNAME": "PRIMARY",
Expand All @@ -29,14 +78,6 @@ class HeaderFactory:

ext_template = {"NAXIS": 2, "NAXIS1": 2048, "NAXIS2": 4096, "CRPIX1": 1024, "CPRIX2": 2048, "BITPIX": 32}

wcs_template = {
"ctype": ["RA---TAN", "DEC--TAN"],
"crval": [351, -5],
"cunit": ["deg", "deg"],
"radesys": "ICRS",
"cd": [[-1.44e-07, 7.32e-05], [7.32e-05, 1.44e-05]],
}

def __validate_mutables(self):
# !xor
if bool(self.mutables) != bool(self.callbacks):
Expand Down Expand Up @@ -87,23 +128,16 @@ def mock(self, n=1):
return headers

@classmethod
def gen_wcs(cls, metadata):
wcs = WCS(naxis=2)
for k, v in metadata.items():
setattr(wcs.wcs, k, v)
return wcs.to_header()

@classmethod
def gen_header(cls, base, overrides, wcs_base=None):
def gen_header(cls, base, overrides, wcs=None):
header = Header(base)
header.update(overrides)

if wcs_base is not None:
naxis1 = header.get("NAXIS1", False)
naxis2 = header.get("NAXIS2", False)
if not all((naxis1, naxis2)):
raise ValueError("Adding a WCS to the header requires " "NAXIS1 and NAXIS2 keys.")
header.update(cls.gen_wcs(wcs_base))
if wcs is not None:
# Sync WCS with header + overwrites
wcs_header = wcs.to_header()
wcs_header.update(header)
# then merge back to mocked header template
header.update(wcs_header)

return header

Expand All @@ -122,9 +156,13 @@ def from_ext_template(cls, overrides=None, mutables=None, callbacks=None, wcs=No
ext_template["CRPIX1"] = shape[0] // 2
ext_template["CRPIX2"] = shape[1] // 2

hdr = cls.gen_header(
base=ext_template, overrides=overrides, wcs_base=cls.wcs_template if wcs is None else wcs
)
if wcs is None:
wcs = make_wcs(
shape=(ext_template["NAXIS1"], ext_template["NAXIS2"]),

)

hdr = cls.gen_header(base=ext_template, overrides=overrides, wcs=wcs)
return cls(hdr, mutables, callbacks)


Expand Down
32 changes: 26 additions & 6 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ def test_static_objects(self):
self.assertTrue(len(results) == 0)


class TestLinearSearch(unittest.TestCase):
class TestRandomLinearSearch(unittest.TestCase):
def setUp(self):
# Set up shared search values
self.n_imgs = 10
self.repeat_n_times = 10
self.shape = (500, 500)
self.start_pos = (10, 50)
self.vxs = [10, 30]
self.vys = [10, 30]
self.shape = (300, 300)
self.start_pos = (125, 175)
self.vxs = [-10, 10]
self.vys = [-10, 10]

# Set up configs for mocking and search
# These don't change from test to test
Expand Down Expand Up @@ -100,7 +100,7 @@ def setUp(self):
}
)

def test_search(self):
def test_simple_search(self):
# Mock the data and repeat tests. The random catalog
# creation guarantees a diverse set of changing test values
for i in range(self.repeat_n_times):
Expand All @@ -123,6 +123,26 @@ def test_search(self):
self.assertLessEqual(abs(obj["vx"] - res["vx"]), 5)
self.assertLessEqual(abs(obj["vy"] - res["vy"]), 5)

def test_diffim_mocks(self):
src_cat = kbmock.SourceCatalog.from_defaults()
obj_cat = kbmock.ObjectCatalog.from_defaults(self.param_ranges, n=1)
factory = kbmock.DECamImdiff.from_defaults(with_data=True, src_cat=src_cat, obj_cat=obj_cat)
hduls = factory.mock(n=self.n_imgs)

ic = ImageCollection.fromTargets(hduls, force="TestDataStd")
wu = ic.toWorkUnit(search_config=self.config)
results = SearchRunner().run_search_from_work_unit(wu)

# Run tests
self.assertGreaterEqual(len(results), 1)
for res in results:
diff = abs(obj_cat.table["y_mean"] - res["y"])
obj = obj_cat.table[diff == diff.min()]
self.assertLessEqual(abs(obj["x_mean"] - res["x"]), 5)
self.assertLessEqual(abs(obj["y_mean"] - res["y"]), 5)
self.assertLessEqual(abs(obj["vx"] - res["vx"]), 5)
self.assertLessEqual(abs(obj["vy"] - res["vy"]), 5)


####

Expand Down

0 comments on commit 8c9ae11

Please sign in to comment.