diff --git a/src/kbmod/mocking/callbacks.py b/src/kbmod/mocking/callbacks.py index 4b2b15750..28c1d9397 100644 --- a/src/kbmod/mocking/callbacks.py +++ b/src/kbmod/mocking/callbacks.py @@ -1,3 +1,8 @@ +import random + +from astropy.time import Time +import astropy.units as u + __all__ = [ "IncrementObstime", "ObstimeIterator", @@ -5,20 +10,33 @@ class IncrementObstime: + default_unit = "day" def __init__(self, start, dt): - self.start = start + self.start = Time(start) + if not isinstance(dt, u.Quantity): + dt = dt * getattr(u, self.default_unit) self.dt = dt - def __call__(self, mut_val): + def __call__(self, header_val): curr = self.start self.start += self.dt - return curr + return curr.fits class ObstimeIterator: - def __init__(self, obstimes): - self.obstimes = obstimes + def __init__(self, obstimes, **kwargs): + self.obstimes = Time(obstimes, **kwargs) self.generator = (t for t in obstimes) - def __call__(self, mut_val): - return next(self.generator) + def __call__(self, header_val): + return Time(next(self.generator)).fits + + +class DitherValue: + def __init__(self, value, dither_range): + self.value = value + self.dither_range = dither_range + + def __call__(self, header_val): + return self.value + random.uniform(self.dither_range) + diff --git a/src/kbmod/mocking/catalogs.py b/src/kbmod/mocking/catalogs.py index f85a9fce3..3063116d9 100644 --- a/src/kbmod/mocking/catalogs.py +++ b/src/kbmod/mocking/catalogs.py @@ -2,11 +2,13 @@ import numpy as np from astropy.table import QTable +from astropy.coordinates import SkyCoord + from .config import Config __all__ = [ - "gen_catalog", + "gen_random_catalog", "CatalogFactory", "SimpleCatalog", "SourceCatalogConfig", @@ -16,7 +18,7 @@ ] -def gen_catalog(n, param_ranges, seed=None): +def gen_random_catalog(n, param_ranges, seed=None): cat = QTable() rng = np.random.default_rng(seed) @@ -30,9 +32,6 @@ def gen_catalog(n, param_ranges, seed=None): # conversion assumes a gaussian if "flux" in param_ranges and "amplitude" not in param_ranges: - xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1.0 - ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1.0 - cat["amplitude"] = cat["flux"] / (2.0 * np.pi * xstd * ystd) return cat @@ -45,7 +44,8 @@ def mock(self, *args, **kwargs): class SimpleCatalogConfig(Config): - mode = "static" + mode = "static" # folding + kind = "pixel" # world return_copy = False seed = None n = 100 @@ -56,28 +56,32 @@ class SimpleCatalog(CatalogFactory): default_config = SimpleCatalogConfig def __init__(self, config, table, **kwargs): - config = self.default_config(**kwargs) - self.config = config + self.config = self.default_config(config=config, **kwargs) self.table = table self.current = 0 @classmethod def from_config(cls, config, **kwargs): config = cls.default_config(config=config, **kwargs) - table = gen_catalog(config.n, config.param_ranges, config.seed) + table = gen_random_catalog(config["n"], config["param_ranges"], config["seed"]) return cls(config, table) @classmethod def from_defaults(cls, param_ranges=None, **kwargs): config = cls.default_config(**kwargs) if param_ranges is not None: - config.param_ranges.update(param_ranges) + config["param_ranges"].update(param_ranges) return cls.from_config(config) @classmethod def from_table(cls, table, **kwargs): + if "x_stddev" not in table.columns: + table["x_stddev"] = table["stddev"] + if "y_stddev" not in table.columns: + table["y_stddev"] = table["stddev"] + config = cls.default_config(**kwargs) - config.n = len(table) + config["n"] = len(table) params = {} for col in table.keys(): params[col] = (table[col].min(), table[col].max()) @@ -86,7 +90,7 @@ def from_table(cls, table, **kwargs): def mock(self): self.current += 1 - if self.config.return_copy: + if self.config["return_copy:"]: return self.table.copy() return self.table @@ -127,7 +131,7 @@ def __init__(self, config, table, **kwargs): kwargs["return_copy"] = True super().__init__(config, table, **kwargs) self._realization = self.table.copy() - self.mode = self.config.mode + self.mode = self.config["mode"] @property def mode(self): @@ -137,8 +141,10 @@ def mode(self): def mode(self, val): if val == "folding": self._gen_realization = self.fold - elif val == "progressive": - self._gen_realization = self.next + elif val == "progressive" and self.config["kind"] == "pixel": + self._gen_realization = self.next_pixel + elif val == "progressive" and self.config["kind"] == "world": + self._gen_realization = self.next_world elif val == "static": self._gen_realization = self.static else: @@ -155,25 +161,38 @@ def reset(self): def static(self, **kwargs): return self.table.copy() - def next(self, dt, **kwargs): - self._realization["x_mean"] = self.table["x_mean"] + self.current * self._realization["vx"] * dt - self._realization["y_mean"] = self.table["y_mean"] + self.current * self._realization["vy"] * dt + def _next(self, dt, keys): + a, va, b, vb = keys + self._realization[a] = self.table[a] + self.current * self.table[va] * dt + self._realization[b] = self.table[b] + self.current * self.table[vb] * dt self.current += 1 return self._realization.copy() + def next_world(self, dt): + return self._next(dt, ["ra_mean", "v_ra", "dec_mean", "v_dec"]) + + def next_pixel(self, dt): + return self._next(dt, ["x_mean", "vx", "y_mean", "vy"]) + def fold(self, t, **kwargs): self._realization = self.table[self.table["obstime"] == t] self.current += 1 return self._realization.copy() - def mock(self, n=1, **kwargs): + def mock(self, n=1, dt=None, t=None, wcs=None): data = [] if self.mode == "folding": - for t in kwargs["t"]: - data.append(self.fold(t=t)) + for i, ts in enumerate(t): + data.append(self.fold(t=ts)) else: for i in range(n): - data.append(self._gen_realization(**kwargs)) + data.append(self._gen_realization(dt)) + + if self.config["kind"] == "world": + for cat, w in zip(data, wcs): + x, y = w.world_to_pixel(SkyCoord(ra=cat["ra_mean"], dec=cat["dec_mean"], unit="deg")) + cat["x_mean"] = x + cat["y_mean"] = y return data diff --git a/src/kbmod/mocking/config.py b/src/kbmod/mocking/config.py index b0ad888c5..7af9b97f9 100644 --- a/src/kbmod/mocking/config.py +++ b/src/kbmod/mocking/config.py @@ -8,13 +8,10 @@ class ConfigurationError(Exception): class Config: - """Base configuration class. + """Base class for Standardizer configuration. - Config classes that inherit from this class define configuration as their - class attributes. Particular attributes can be overriden on an per-instance - basis by providing a config overrides at initialization time. - - Configs inheriting from this config support basic dictionary operations. + Not all standardizers will (can) use the same parameters so refer to their + respective documentation for a more complete list. Parameters ---------- @@ -24,66 +21,33 @@ class attributes. Particular attributes can be overriden on an per-instance Keyword arguments, assigned as configuration key-values. """ - def __init__(self, config=None, method="default", **kwargs): + def __init__(self, config=None, **kwargs): # This is a bit hacky, but it makes life a lot easier because it # enables automatic loading of the default configuration and separation # of default config from instance bound config keys = list(set(dir(self.__class__)) - set(dir(Config))) # First fill out all the defaults by copying cls attrs - self._conf = {k: copy.copy(getattr(self, k)) for k in keys} + self._conf = {k: getattr(self, k) for k in keys} # Then override with any user-specified values - self.update(config=config, method=method, **kwargs) - - @classmethod - def from_configs(cls, *args): - config = cls() - for conf in args: - config.update(config=conf, method="extend") - return config + if config is not None: + self._conf.update(config) + self._conf.update(kwargs) + # now just shortcut the most common dict operations def __getitem__(self, key): return self._conf[key] - # now just shortcut the most common dict operations - def __getattribute__(self, key): - hasconf = "_conf" in object.__getattribute__(self, "__dict__") - if hasconf: - conf = object.__getattribute__(self, "_conf") - if key in conf: - return conf[key] - return object.__getattribute__(self, key) - def __setitem__(self, key, value): self._conf[key] = value - def __repr__(self): - res = f"{self.__class__.__name__}(" - for k, v in self.items(): - res += f"{k}: {v}, " - return res[:-2] + ")" - def __str__(self): res = f"{self.__class__.__name__}(" for k, v in self.items(): res += f"{k}: {v}, " return res[:-2] + ")" - def _repr_html_(self): - repr = f""" - - - - - - - """ - for k, v in self.items(): - repr += f"
{self.__class__.__name__}
KeyValue
{k}{v}\n" - repr += "
" - return repr - def __len__(self): return len(self._conf) @@ -107,7 +71,7 @@ def __or__(self, other): elif isinstance(other, dict): return self.__class__(config=self._conf | other) else: - raise TypeError("unsupported operand type(s) for |: {type(self)}and {type(other)}") + raise TypeError("unsupported operand type(s) for |: {type(self)} " "and {type(other)}") def keys(self): """A set-like object providing a view on config's keys.""" @@ -121,68 +85,22 @@ def items(self): """A set-like object providing a view on config's items.""" return self._conf.items() - def copy(self): - return self.__class__(config=self._conf.copy()) - - def update(self, config=None, method="default", **kwargs): - """Update this config from dict/other config/iterable and - apply any explicit keyword overrides. - - A dict-like update. If ``conf`` is given and has a ``.keys()`` - method, performs: - - for k in conf: this[k] = conf[k] - - If ``conf`` is given but lacks a ``.keys()`` method, performs: + def update(self, conf=None, **kwargs): + """Update this config from dict/other config/iterable. - for k, v in conf: this[k] = v + A dict-like update. If ``conf`` is present and has a ``.keys()`` + method, then does: ``for k in conf: this[k] = conf[k]``. If ``conf`` + is present but lacks a ``.keys()`` method, then does: + ``for k, v in conf: this[k] = v``. - In both cases, explicit overrides are applied at the end: - - for k in kwargs: this[k] = kwargs[k] + In either case, this is followed by: + ``for k in kwargs: this[k] = kwargs[k]`` """ - # Python < 3.9 does not support set operations for dicts - # [fixme]: Update this to: other = conf | kwargs - # and remove current implementation when 3.9 gets too old. Order of - # conf and kwargs matter to correctly apply explicit overrides - - # Check if both conf and kwargs are given, just conf or just - # kwargs. If none are given do nothing to comply with default - # dict behavior - if config is not None and kwargs: - other = {**config, **kwargs} - elif config is not None: - other = config - elif kwargs is not None: - other = kwargs - else: - return - - # then, see if we the given config and overrides are a subset of this - # config or it's superset. Depending on the selected method then raise - # errors, ignore or extend the current config if the given config is a - # superset (or disjoint) from the current one. - subset = {k: v for k, v in other.items() if k in self._conf} - superset = {k: v for k, v in other.items() if k not in subset} - - if method.lower() == "default": - if superset: - raise ConfigurationError( - "Tried setting the following fields, not a part of " - f"this configuration options: {superset}" - ) - conf = other # == subset - elif method.lower() == "subset": - conf = subset - elif method.lower() == "extend": - conf = other - else: - raise ValueError( - "Method expected to be one of 'default', " f"'subset' or 'extend'. Got {method} instead." - ) - - self._conf.update(conf) + if conf is not None: + self._conf.update(conf) + self._conf.update(kwargs) def toDict(self): """Return this config as a dict.""" return self._conf + diff --git a/src/kbmod/mocking/fits.py b/src/kbmod/mocking/fits.py index d10608ae7..5db6650ad 100644 --- a/src/kbmod/mocking/fits.py +++ b/src/kbmod/mocking/fits.py @@ -1,4 +1,5 @@ from astropy.io.fits import HDUList, PrimaryHDU, CompImageHDU, BinTableHDU +from astropy.wcs import WCS from .callbacks import IncrementObstime from .headers import HeaderFactory, ArchivedHeader @@ -25,57 +26,32 @@ class NoneFactory: "Kinda makes some code later prettier. Kinda" - def mock(self, n): return [ None, ] * n -def hdu_cast(hdu_cls, hdr, data=None, validate_header=False, update_header=False): - hdu = hdu_cls() - - if validate_header: - hdu.header.update(hdr) - else: - hdu.header = hdr - - if data is not None: - hdu.data = data - if update_header: - hdu.update_header() - - return hdu - - -def hdu_cast_array(hdu_cls, hdr, data, validate_header=False, update_header=False): - hdus = [] - for hdr, dat in zip(hdr, data): - hdus.append(hdu_cast(hdu_cls, hdr, dat)) - return hdus - - class EmptyFits: def __init__( self, header=None, shape=(100, 100), - start_mjd=60310, - step_mjd=0.001, + start_t="2024-01-01T00:00:00.00", + step_t=0.001, editable_images=False, editable_masks=False, ): self.prim_hdr = HeaderFactory.from_primary_template( - overrides=header, mutables=["OBS-MJD"], callbacks=[IncrementObstime(start=start_mjd, dt=step_mjd)] + overrides=header, + mutables=["DATE-OBS"], + callbacks=[IncrementObstime(start=start_t, dt=step_t)] ) self.img_hdr = HeaderFactory.from_ext_template({"EXTNAME": "IMAGE"}, shape=shape) self.var_hdr = HeaderFactory.from_ext_template({"EXTNAME": "VARIANCE"}, shape=shape) self.mask_hdr = HeaderFactory.from_ext_template({"EXTNAME": "MASK"}, shape=shape) - # 2.2) Then data factories, attempt to save performance and memory - # where possible by really only allocating 1 array whenever the - # data is read-only and content-static between created HDUs. self.img_data = DataFactory.from_header( kind="image", header=self.img_hdr.header, writeable=editable_images, return_copy=editable_images ) @@ -86,24 +62,26 @@ def __init__( self.current = 0 def mock(self, n=1): - img_hdr = self.img_hdr.mock()[0] - var_hdr = self.var_hdr.mock()[0] - mask_hdr = self.mask_hdr.mock()[0] + prim_hdrs = self.prim_hdr.mock(n=n) + img_hdrs = self.img_hdr.mock(n=n) + var_hdrs = self.var_hdr.mock(n=n) + mask_hdrs = self.mask_hdr.mock(n=n) + images = self.img_data.mock(n=n) variances = self.img_data.mock(n=n) masks = self.mask_data.mock(n=n) hduls = [] - for i in range(n): + for ph, ih, vh, mh, imd, vd, md in zip( + prim_hdrs, img_hdrs, var_hdrs, mask_hdrs, images, variances, masks + ): hduls.append( - HDUList( - hdus=[ - PrimaryHDU(header=self.prim_hdr.mock()[0]), - CompImageHDU(header=img_hdr, data=images[i]), - CompImageHDU(header=var_hdr, data=variances[i]), - CompImageHDU(header=mask_hdr, data=masks[i]), - ] - ) + HDUList(hdus=[ + PrimaryHDU(header=ph), + CompImageHDU(header=ih, data=imd), + CompImageHDU(header=vh, data=vd), + CompImageHDU(header=mh, data=md) + ]) ) self.current += n @@ -112,25 +90,48 @@ def mock(self, n=1): class SimpleFits: def __init__( - self, - header=None, - shape=(100, 100), - start_mjd=60310, - step_mjd=0.001, - with_noise=False, - noise="simplistic", - src_cat=None, - obj_cat=None, + self, + shared_header_metadata=None, + shape=(100, 100), + start_t="2024-01-01T00:00:00.00", + step_t=0.001, + with_noise=False, + noise="simplistic", + src_cat=None, + obj_cat=None, + wcs_factory=None, ): # 2. Set up Header and Data factories that go into creating HDUs # 2.1) First headers, since that metadata specified data formats self.prim_hdr = HeaderFactory.from_primary_template( - overrides=header, mutables=["OBS-MJD"], callbacks=[IncrementObstime(start=start_mjd, dt=step_mjd)] + overrides=shared_header_metadata, + mutables=["DATE-OBS"], callbacks=[IncrementObstime(start=start_t, dt=step_t)] ) - self.img_hdr = HeaderFactory.from_ext_template({"EXTNAME": "IMAGE"}, shape=shape) - self.var_hdr = HeaderFactory.from_ext_template({"EXTNAME": "VARIANCE"}, shape=shape) - self.mask_hdr = HeaderFactory.from_ext_template({"EXTNAME": "MASK"}, shape=shape) + wcs = None + if wcs_factory is not None: + wcs = wcs_factory + + if shared_header_metadata is None: + shared_header_metadata = {"EXTNAME": "IMAGE"} + + self.img_hdr = HeaderFactory.from_ext_template( + overrides=shared_header_metadata.copy(), + shape=shape, + wcs=wcs + ) + shared_header_metadata["EXTNAME"] = "VARIANCE" + self.var_hdr = HeaderFactory.from_ext_template( + overrides=shared_header_metadata.copy(), + shape=shape, + wcs=wcs + ) + shared_header_metadata["EXTNAME"] = "MASK" + self.mask_hdr = HeaderFactory.from_ext_template( + overrides=shared_header_metadata.copy(), + shape=shape, + wcs=wcs + ) # 2.2) Then data factories if noise == "realistic": @@ -140,8 +141,8 @@ def __init__( self.var_data = SimpleVariance(self.img_data.base) self.mask_data = SimpleMask.from_image(self.img_data.base) - self.start_mjd = start_mjd - self.step_mjd = step_mjd + self.start_t = start_t + self.step_t = step_t self.obj_cat = obj_cat self.current = 0 @@ -153,8 +154,12 @@ def mock(self, n=1): obj_cats = None if self.obj_cat is not None: - kwargs = {"dt": self.step_mjd, "t": [hdr["OBS-MJD"] for hdr in prim_hdrs]} - obj_cats = self.obj_cat.mock(n=n, **kwargs) + obj_cats = self.obj_cat.mock( + n=n, + dt=self.step_t, + t=[hdr["DATE-OBS"] for hdr in prim_hdrs], + wcs=[WCS(hdr) for hdr in img_hdrs] + ) images = self.img_data.mock(n, obj_cats=obj_cats) variances = self.var_data.mock(images=images) @@ -165,14 +170,12 @@ def mock(self, n=1): prim_hdrs, img_hdrs, var_hdrs, mask_hdrs, images, variances, masks ): hduls.append( - HDUList( - hdus=[ - PrimaryHDU(header=ph), - CompImageHDU(header=ih, data=imd), - CompImageHDU(header=vh, data=vd), - CompImageHDU(header=mh, data=md), - ] - ) + HDUList(hdus=[ + PrimaryHDU(header=ph), + CompImageHDU(header=ih, data=imd), + CompImageHDU(header=vh, data=vd), + CompImageHDU(header=mh, data=md) + ]) ) self.current += n @@ -180,79 +183,78 @@ def mock(self, n=1): class DECamImdiff: - @classmethod - def from_defaults( - cls, - with_data=False, - override_original=False, - shape=(100, 100), - start_mjd=60310, - step_mjd=0.001, - with_noise=False, - noise="simplistic", - src_cat=None, - obj_cat=None - ): - if obj_cat.config.type == "progressive": + def __init__(self, with_data=False, with_noise=False, noise="simplistic", + src_cat=None, obj_cat=None): + if obj_cat is not None and obj_cat.mode == "progressive": raise ValueError( "Only folding or static object catalogs can be used with" "default archived DECam headers since header timestamps are not " "required to be equally spaced." ) - hdr_factory = ArchivedHeader("headers_archive.tar.bz2", "decam_imdiff_headers.ecsv") - - hdu_types = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU] - hdu_types.extend([BinTableHDU] * 12) - data = [NoneFactory()] * 16 + self.hdr_factory = ArchivedHeader("headers_archive.tar.bz2", "decam_imdiff_headers.ecsv") + self.data_factories = [NoneFactory()] * 16 if with_data: - headers = hdr_factory.get(0) + headers = self.hdr_factory.get(0) shape = (headers[1]["NAXIS1"], headers[1]["NAXIS2"]) dtype = DataFactoryConfig.bitpix_type_map[headers[1]["BITPIX"]] - # 2.1) Now we can instantiate data factories with correct configs - # and fill in the data placeholder + # Read noise and gain are typical values. DECam has 2 amps per CCD, + # each powering ~half of the plane. Their values and areas are + # recorded in the header, but that would mean I would have to + # produce an image which has different zero-offsets for the two + # halves which is too much detail for this use-case. Typical values + # are taken from the DECam Data Handbook Version 2.05 March 2014 + # Table 2.2 if noise == "realistic": - img_data = SimulatedImage(src_cat=src_cat, shape=shape, dtype=dtype) + self.img_data = SimulatedImage(src_cat=src_cat, shape=shape, dtype=dtype) else: - img_data = SimpleImage(src_cat=src_cat, shape=shape, dtype=dtype) - var_data = SimpleVariance(img_data.base) - mask_data = SimpleMask.from_image(img_data.base) - - data = [NoneFactory(), img_data, var_data, mask_data] - data.extend([DataFactory.from_header("table", h) for h in headers[4:]]) + self.img_data = SimpleImage(src_cat=src_cat, shape=shape, dtype=dtype) + self.var_data = SimpleVariance(self.img_data.base, read_noise=7.0, gain=4.0) + self.mask_data = SimpleMask.from_image(self.img_data.base) - return cls(hdr_factory, data_factories=data, obj_cat=obj_cat) + self.data_factories[1] = self.img_data + self.data_factories[2] = self.mask_data + self.data_factories[3] = self.mask_data + self.data_factories[4:] = [DataFactory.from_header("table", h) for h in headers[4:]] - def __init__(self, header_factory, data_factories=None, obj_cat=None): - self.hdr_factory = header_factory - self.data_factories = data_factories + self.with_data = with_data + self.src_cat = src_cat + self.obj_cat = obj_cat self.hdu_layout = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU] self.hdu_layout.extend([BinTableHDU] * 12) + self.current = 0 def mock(self, n=1): + headers = self.hdr_factory.mock(n=n) + obj_cats = None if self.obj_cat is not None: - obj_cats = self.obj_cat.mock(n, dt=self.config.dt) - - hdrs = self.hdr_factory.mock(n) + kwargs = {"t": [hdrs[0][0]["DATE-AVG"] for hdr in hdrs]} + obj_cats = self.obj_cat.mock(n=n, **kwargs) - if self.data_factories is not None: - images = self.img_data.mock(n, obj_cats=obj_cats) + if self.with_data: + images = self.img_data.mock(n=n, obj_cats=obj_cats) + masks = self.mask_data.mock(n=n) variances = self.var_data.mock(images=images) - data = [self.data[0].mock(n), images, variances] - for factory in self.data[3:]: - data.append(factory.mock(n=n)) + data = [ + NoneFactory().mock(n=n), + images, + masks, + variances + ] + data.extend([factory.mock(n=n) for factory in self.data_factories[4:]]) else: - data = [f.mock(n=n) for f in self.data] + data = [factory.mock(n=n) for factory in self.data_factories] hduls = [] - for hdul_idx in range(n): + for i, hdrs in enumerate(headers): hdus = [] - for hdu_idx, hdu_cls in enumerate(self.hdu_types): - hdus.append(self.hdu_cast(hdu_cls, hdrs[hdul_idx][hdu_idx], data[hdu_idx][hdul_idx])) + for j, (layer, hdr) in enumerate(zip(self.hdu_layout, hdrs)): + hdus.append(layer(header=hdr, data=data[j][i])) hduls.append(HDUList(hdus=hdus)) + self.current += n return hduls diff --git a/src/kbmod/mocking/fits_data.py b/src/kbmod/mocking/fits_data.py index e11ff8437..669c26917 100644 --- a/src/kbmod/mocking/fits_data.py +++ b/src/kbmod/mocking/fits_data.py @@ -61,7 +61,7 @@ def add_model_objects(img, catalog, model): setattr(model, param, source[param]) if all( - [model.x_mean > 0, model.x_mean < img.shape[1], model.y_mean > 0, model.y_mean < img.shape[0]] + [model.x_mean > 0, model.x_mean < img.shape[1], model.y_mean > 0, model.y_mean < img.shape[0]] ): model.render(img) finally: @@ -157,8 +157,8 @@ class DataFactory: default_config = DataFactoryConfig """Default configuration.""" - def __init__(self, base, config=None, **kwargs): - self.config = self.default_config(config, **kwargs) + def __init__(self, base, **kwargs): + self.config = self.default_config(**kwargs) self.base = base if base is None: @@ -167,12 +167,12 @@ def __init__(self, base, config=None, **kwargs): else: self.shape = base.shape self.dtype = base.dtype - self.base.flags.writeable = self.config.writeable - self.counter = 0 + self.base.flags.writeable = self.config["writeable"] + self.counter = 0 @classmethod - def gen_image(cls, metadata=None, config=None, **kwargs): - conf = cls.default_config(config, method="subset", **kwargs) + def gen_image(cls, metadata=None, **kwargs): + conf = cls.default_config(**kwargs) cols = metadata.get("NAXIS1", conf.default_img_shape[0]) rows = metadata.get("NAXIS2", conf.default_img_shape[1]) bitwidth = metadata.get("BITPIX", conf.default_img_bit_width) @@ -247,11 +247,11 @@ def mock(self, n=1, **kwargs): "Use `zeros` or `from_hdu` to construct this object correctly." ) - if self.config.return_copy: + if self.config["return_copy"]: base = np.repeat(self.base[np.newaxis,], (n,), axis=0) else: base = np.broadcast_to(self.base, (n, *self.shape)) - base.flags.writeable = self.config.writeable + base.flags.writeable = self.config["writeable"] return base @@ -288,18 +288,18 @@ class SimpleVariance(DataFactory): default_config = SimpleVarianceConfig - def __init__(self, image=None, config=None, **kwargs): + def __init__(self, image=None, **kwargs): # skip setting the base here since the real base is # derived from given image we just set it manually - super().__init__(base=None, config=config, **kwargs) + super().__init__(base=None, **kwargs) if image is not None: - self.base = image / self.config.gain + self.config.read_noise**2 + self.base = image / self.config["gain"] + self.config["read_noise"]**2 def mock(self, images=None): if images is None: return self.base - return images / self.config.gain + self.config.read_noise**2 + return images / self.config["gain"] + self.config["read_noise"]**2 class SimpleMaskConfig(DataFactoryConfig): @@ -330,18 +330,18 @@ class SimpleMask(DataFactory): default_config = SimpleMaskConfig - def __init__(self, mask, config=None, **kwargs): - super().__init__(base=mask, config=config, **kwargs) + def __init__(self, mask, **kwargs): + super().__init__(base=mask, **kwargs) @classmethod - def from_image(cls, image, config=None, **kwargs): - config = cls.default_config(config=config, **kwargs, method="subset") + def from_image(cls, image, **kwargs): + config = cls.default_config(**kwargs) mask = image.copy() - mask[image > config.threshold] = 1 + mask[image > config["threshold"]] = 1 return cls(mask) @classmethod - def from_params(cls, config=None, **kwargs): + def from_params(cls, **kwargs): """Create a mask by adding a padding around the edges of the array with the given dimensions and mask out bad columns. @@ -387,10 +387,10 @@ def from_params(cls, config=None, **kwargs): [1., 0., 1., 1., 0., 0., 0., 0., 0., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]) """ - config = cls.default_config(config=config, **kwargs, method="subset") - mask = np.zeros(config.shape, dtype=config.dtype) + config = cls.default_config(**kwargs) + mask = np.zeros(config["shape"], dtype=config["dtype"]) - shape, padding = config.shape, config.padding + shape, padding = config["shape"], config["padding"] # padding mask[:padding] = 1 @@ -399,10 +399,10 @@ def from_params(cls, config=None, **kwargs): mask[: shape[1] - padding :] = 1 # bad columns - for col in config.bad_columns: + for col in config["bad_columns"]: mask[:, col] = 1 - for patch, value in config.patches: + for patch, value in config["patches"]: if isinstance(patch, tuple): mask[patch] = 1 elif isinstance(slice): @@ -410,7 +410,7 @@ def from_params(cls, config=None, **kwargs): else: raise ValueError(f"Expected a tuple (x, y), (slice, slice) or slice, got {patch} instead.") - return cls(mask, config=config) + return cls(mask, **config) class SimpleImageConfig(DataFactoryConfig): @@ -464,16 +464,15 @@ class SimpleImage(DataFactory): default_config = SimpleImageConfig - def __init__(self, image=None, src_cat=None, obj_cat=None, config=None, + def __init__(self, image=None, src_cat=None, obj_cat=None, dtype=np.float32, **kwargs): - self.config = self.default_config(config=config, **kwargs) - super().__init__(image, self.config, **kwargs) + super().__init__(image, **kwargs) if image is None: - image = np.zeros(self.config.shape, dtype=dtype) + image = np.zeros(self.config["shape"], dtype=dtype) else: image = image - self.config.shape = image.shape + self.config["shape"] = image.shape # Astropy throws a strange ValueError instead of reporting a non-writeable # array, This must be a bug TODO: report. It's not safe to edit a @@ -481,8 +480,8 @@ def __init__(self, image=None, src_cat=None, obj_cat=None, config=None, self.src_cat = src_cat if self.src_cat is not None: image = image if image.flags.writeable else image.copy() - add_model_objects(image, src_cat.table, self.config.model(x_stddev=1, y_stddev=1)) - image.flags.writeable = self.config.writeable + add_model_objects(image, src_cat.table, self.config["model"](x_stddev=1, y_stddev=1)) + image.flags.writeable = self.config["writeable"] self.base = image self._base_contains_data = image.sum() != 0 @@ -498,16 +497,16 @@ def add_noise(cls, images, config): config : `SimpleImageConfig` Configuration. """ - rng = np.random.default_rng(seed=config.seed) + rng = np.random.default_rng(seed=config["seed"]) shape = images.shape # noise has to be resampled for every image rng.standard_normal(size=shape, dtype=images.dtype, out=images) # There's a lot of multiplications that happen, skip if possible - if config.noise_std != 1.0: - images *= config.noise_std - images += config.noise + if config["noise_std"] != 1.0: + images *= config["noise_std"] + images += config["noise"] return images @@ -522,10 +521,10 @@ def mock(self, n=1, obj_cats=None, **kwargs): A list of catalogs as long as the number of requested images of moving objects that will be inserted into the image. """ - shape = (n, *self.config.shape) + shape = (n, *self.config["shape"]) images = np.zeros(shape, dtype=np.float32) - if self.config.add_noise: + if self.config["add_noise"]: images = self.add_noise(images=images, config=self.config) # if base has no data (no sources, bad cols etc) skip @@ -540,7 +539,7 @@ def mock(self, n=1, obj_cats=None, **kwargs): if obj_cats is not None: pairs = [(images[0], obj_cats[0])] if n == 1 else zip(images, obj_cats) for i, (img, cat) in enumerate(pairs): - add_model_objects(img, cat, self.config.model(x_stddev=1, y_stddev=1)) + add_model_objects(img, cat, self.config["model"](x_stddev=1, y_stddev=1)) return images @@ -701,24 +700,24 @@ def add_bad_cols(cls, image, config): image : `np.array` Image. """ - if not config.add_bad_columns: + if not config["add_bad_columns"]: return image shape = image.shape - rng = np.random.RandomState(seed=config.bad_cols_seed) - if config.bad_cols_method == "random": - bad_cols = rng.randint(0, shape[1], size=config.n_bad_cols) - elif config.bad_col_locs: - bad_cols = config.bad_col_locs + rng = np.random.RandomState(seed=config["bad_cols_seed"]) + if config["bad_cols_method"] == "random": + bad_cols = rng.randint(0, shape[1], size=config["n_bad_cols"]) + elif config["bad_col_locs"]: + bad_cols = config["bad_col_locs"] else: raise ConfigurationError( - "Bad columns method is not 'random', but `bad_col_locs` " "contains no column indices." + "Bad columns method is not 'random', but `bad_col_locs` contains no column indices." ) - col_pattern = rng.randint(low=0, high=int(config.bad_col_pattern_offset), size=shape[0]) + col_pattern = rng.randint(low=0, high=int(config["bad_col_pattern_offset"]), size=shape[0]) for col in bad_cols: - image[:, col] += col_pattern + config.bad_col_offset + image[:, col] += col_pattern + config["bad_col_offset"] return image @@ -741,16 +740,16 @@ def add_hot_pixels(cls, image, config): image : `np.array` Image. """ - if not config.add_hot_pixels: + if not config["add_hot_pixels"]: return image shape = image.shape - if config.hot_pix_method == "random": - rng = np.random.RandomState(seed=config.hot_pix_seed) - x = rng.randint(0, shape[1], size=config.n_hot_pix) - y = rng.randint(0, shape[0], size=config.n_hot_pix) + if config["hot_pix_method"] == "random": + rng = np.random.RandomState(seed=config["hot_pix_seed"]) + x = rng.randint(0, shape[1], size=config["n_hot_pix"]) + y = rng.randint(0, shape[0], size=config["n_hot_pix"]) hot_pixels = np.column_stack([x, y]) - elif config.hot_pix_locs: + elif config["hot_pix_locs"]: hot_pixels = pixels else: raise ConfigurationError( @@ -759,7 +758,7 @@ def add_hot_pixels(cls, image, config): ) for pix in hot_pixels: - image[*pix] += config.hot_pix_offset + image[*pix] += config["hot_pix_offset"] return image @@ -783,14 +782,14 @@ def add_noise(cls, images, config): shape = images.shape # add read noise - images += config.read_noise_gen(scale=config.read_noise / config.gain, size=shape) + images += config["read_noise_gen"](scale=config["read_noise"] / config["gain"], size=shape) # add dark current - current = config.dark_current * config.exposure_time / config.gain - images += config.dark_current_gen(current, size=shape) + current = config["dark_current"] * config["exposure_time"] / config["gain"] + images += config["dark_current_gen"](current, size=shape) # add sky counts - images += config.sky_count_gen(lam=config.sky_level * config.gain, size=shape) / config.gain + images += config["sky_count_gen"](lam=config["sky_level"] * config["gain"], size=shape) / config["gain"] return images @@ -813,17 +812,15 @@ def gen_base_image(cls, config=None, src_cat=None, dtype=np.float32): config = cls.default_config(config) # empty image - base = np.zeros(config.shape, dtype=dtype) - base += config.bias + base = np.zeros(config["shape"], dtype=dtype) + base += config["bias"] base = cls.add_hot_pixels(base, config) base = cls.add_bad_cols(base, config) if src_cat is not None: - add_model_objects(base, src_cat.table, config.model(x_stddev=1, y_stddev=1)) + add_model_objects(base, src_cat.table, config["model"](x_stddev=1, y_stddev=1)) return base - def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, dtype=np.float32,**kwargs): - conf = self.default_config(config=config, **kwargs) - # static objects are added in SimpleImage init - super().__init__(image=self.gen_base_image(conf, dtype=dtype), - config=conf, src_cat=src_cat, obj_cat=obj_cat) + def __init__(self, image=None, src_cat=None, obj_cat=None, dtype=np.float32,**kwargs): + conf = self.default_config(**kwargs) + super().__init__(image=self.gen_base_image(conf, dtype=dtype), src_cat=src_cat, obj_cat=obj_cat, **conf) diff --git a/src/kbmod/mocking/headers.py b/src/kbmod/mocking/headers.py index 6716fb2aa..dd7977cb9 100644 --- a/src/kbmod/mocking/headers.py +++ b/src/kbmod/mocking/headers.py @@ -1,66 +1,123 @@ +import random import warnings +import itertools import numpy as np from astropy.wcs import WCS from astropy.io.fits import Header +from astropy.io.fits.verify import VerifyWarning from .utils import header_archive_to_table from .config import Config __all__ = [ +# "make_wcs", + "WCSFactory", "HeaderFactory", "ArchivedHeader", ] +class WCSFactory: + def __init__(self, mode="static", + pointing=(351., -5), rotation=0, pixscale=0.2, + dither_pos=False, dither_rot=False, dither_amplitudes=(0.01, 0.01, 0.0), + cycle=None): + self.pointing = pointing + self.rotation = rotation + self.pixscale = pixscale -def make_wcs(center_coords=(351., -5.), rotation=0, pixscale=0.2, shape=None): - """ - Create a simple celestial `~astropy.wcs.WCS` object in ICRS - coordinate system. - - Parameters - ---------- - shape : tuple[int] - Two-tuple, dimensions of the WCS footprint - center_coords : tuple[int] - Two-tuple of on-sky coordinates of the center of the WCS in - decimal degrees, in ICRS. - rotation : float, optional - Rotation in degrees, from ICRS equator. In decimal degrees. - pixscale : float - Pixel scale in arcsec/pixel. - - Returns - ------- - wcs : `astropy.wcs.WCS` + self.dither_pos = dither_pos + self.dither_rot = dither_rot + self.dither_amplitudes = dither_amplitudes + self.cycle = cycle + + self.template = self.gen_wcs(self.pointing, self.rotation, self.pixscale) + + self.mode = mode + if dither_pos or dither_rot or cycle is not None: + self.mode = "dynamic" + self.current = 0 + + @classmethod + def gen_wcs(cls, center_coords, rotation, pixscale, shape=None): + """ + Create a simple celestial `~astropy.wcs.WCS` object in ICRS + coordinate system. + + Parameters + ---------- + shape : tuple[int] + Two-tuple, dimensions of the WCS footprint + center_coords : tuple[int] + Two-tuple of on-sky coordinates of the center of the WCS in + decimal degrees, in ICRS. + rotation : float, optional + Rotation in degrees, from ICRS equator. In decimal degrees. + pixscale : float + Pixel scale in arcsec/pixel. + + Returns + ------- + wcs : `astropy.wcs.WCS` The world coordinate system. - Examples - -------- - >>> from kbmod.mocking import make_wcs - >>> shape = (100, 100) - >>> wcs = make_wcs(shape) - >>> wcs = make_wcs(shape, (115, 5), 45, 0.1) - """ - wcs = WCS(naxis=2) - rho = rotation*0.0174533 # deg to rad - scale = 0.1 / 3600.0 # arcsec/pixel to deg/pix - - if shape is not None: - wcs.pixel_shape = shape - wcs.wcs.crpix = [shape[1] / 2, shape[0] / 2] - else: - wcs.wcs.crpix = [0, 0] - wcs.wcs.crval = center_coords - wcs.wcs.cunit = ['deg', 'deg'] - wcs.wcs.cd = [[-scale * np.cos(rho), scale * np.sin(rho)], - [scale * np.sin(rho), scale * np.cos(rho)]] - wcs.wcs.radesys = 'ICRS' - wcs.wcs.ctype = ['RA---TAN', 'DEC--TAN'] - - return wcs + Examples + -------- + >>> from kbmod.mocking import make_wcs + >>> shape = (100, 100) + >>> wcs = make_wcs(shape) + >>> wcs = make_wcs(shape, (115, 5), 45, 0.1) + """ + wcs = WCS(naxis=2) + rho = rotation*0.0174533 # deg to rad + scale = pixscale / 3600.0 # arcsec/pixel to deg/pix + + if shape is not None: + wcs.pixel_shape = shape + wcs.wcs.crpix = [shape[1] // 2, shape[0] // 2] + else: + wcs.wcs.crpix = [0, 0] + wcs.wcs.crval = center_coords + wcs.wcs.cunit = ['deg', 'deg'] + wcs.wcs.pc = np.array([ + [-scale * np.cos(rho), scale * np.sin(rho)], + [scale * np.sin(rho), scale * np.cos(rho)] + ]) + wcs.wcs.radesys = 'ICRS' + wcs.wcs.ctype = ['RA---TAN', 'DEC--TAN'] + + return wcs + + def update_from_header(self, header): + t = self.template.to_header() + t.update(header) + self.template = WCS(t) + + def mock(self, header): + wcs = self.template + + if self.cycle is not None: + wcs = self.cycle[self.current % len(self.cycle)] + + if self.dither_pos: + dra = random.uniform(-self.dither_amplitudes[0], self.dither_amplitudes[0]) + ddec = random.uniform(-self.dither_amplitudes[1], self.dither_amplitudes[1]) + wcs.wcs.crval += [dra, ddec] + if self.dither_rot: + ddec = random.uniform(-self.dither_amplitudes[2], self.dither_amplitudes[2]) + rho = self.dither_amplitudes[2]*0.0174533 # deg to rad + rot_matrix = np.array([ + [-np.cos(rho), np.sin(rho)], + [np.sin(rho), np.cos(rho)] + ]) + new_pc = wcs.wcs.pc @ rot_matrix + wcs.wcs.pc = new_pc + + self.current += 1 + header.update(wcs.to_header()) + return header class HeaderFactory: @@ -68,7 +125,7 @@ class HeaderFactory: "EXTNAME": "PRIMARY", "NAXIS": 0, "BITPIX": 8, - "OBS-MJD": 58914.0, + "DATE-OBS": "2021-03-19T00:27:21.140552", "NEXTEND": 3, "OBS-LAT": -30.166, "OBS-LONG": -70.814, @@ -76,7 +133,7 @@ class HeaderFactory: "OBSERVAT": "CTIO", } - ext_template = {"NAXIS": 2, "NAXIS1": 2048, "NAXIS2": 4096, "CRPIX1": 1024, "CPRIX2": 2048, "BITPIX": 32} + ext_template = {"NAXIS": 2, "NAXIS1": 2048, "NAXIS2": 4096, "CRPIX1": 1024, "CRPIX2": 2048, "BITPIX": 32} def __validate_mutables(self): # !xor @@ -101,30 +158,38 @@ def __validate_mutables(self): "provide the required metadata keys." ) - def __init__(self, metadata, mutables=None, callbacks=None, config=None, **kwargs): + def __init__(self, metadata, mutables=None, callbacks=None, has_wcs=False, wcs_factory=None): cards = [] if metadata is None else metadata self.header = Header(cards=cards) - self.mutables = mutables self.callbacks = callbacks self.__validate_mutables() - self.is_dynamic = self.mutables is not None + self.is_dynamic = mutables is not None + + self.has_wcs = has_wcs + if has_wcs: + self.wcs_factory = WCSFactory() if wcs_factory is None else wcs_factory + self.wcs_factory.update_from_header(self.header) + self.is_dynamic = self.is_dynamic or self.wcs_factory.mode != "static" + self.counter = 0 def mock(self, n=1): headers = [] # This can't be vectorized because callbacks may share global state for i in range(n): - if self.is_dynamic: - header = self.header.copy() - for i, mutable in enumerate(self.mutables): - header[mutable] = self.callbacks[i](header[mutable]) - else: + if not self.is_dynamic: header = self.header + else: + header = self.header.copy() + if self.mutables is not None: + for i, mutable in enumerate(self.mutables): + header[mutable] = self.callbacks[i](header[mutable]) + if self.has_wcs: + header = self.wcs_factory.mock(header) headers.append(header) self.counter += 1 - return headers @classmethod @@ -147,7 +212,8 @@ def from_primary_template(cls, overrides=None, mutables=None, callbacks=None): return cls(hdr, mutables, callbacks) @classmethod - def from_ext_template(cls, overrides=None, mutables=None, callbacks=None, wcs=None, shape=None): + def from_ext_template(cls, overrides=None, mutables=None, callbacks=None, shape=None, + wcs=None): ext_template = cls.ext_template.copy() if shape is not None: @@ -156,14 +222,8 @@ def from_ext_template(cls, overrides=None, mutables=None, callbacks=None, wcs=No ext_template["CRPIX1"] = shape[0] // 2 ext_template["CRPIX2"] = shape[1] // 2 - if wcs is None: - wcs = make_wcs( - shape=(ext_template["NAXIS1"], ext_template["NAXIS2"]), - - ) - - hdr = cls.gen_header(base=ext_template, overrides=overrides, wcs=wcs) - return cls(hdr, mutables, callbacks) + hdr = cls.gen_header(base=ext_template, overrides=overrides) + return cls(hdr, mutables, callbacks, has_wcs=True, wcs_factory=wcs) class ArchivedHeader(HeaderFactory): @@ -183,8 +243,8 @@ class ArchivedHeader(HeaderFactory): format = "ascii.ecsv" - def __init__(self, archive_name, fname, config=None, **kwargs): - super().__init__(config, **kwargs) + def __init__(self, archive_name, fname): + super().__init__({}) self.table = header_archive_to_table(archive_name, fname, self.compression, self.format) # Create HDU groups for easier iteration @@ -220,7 +280,10 @@ def get(self, group_idx): for subgroup in subgroup.groups: header = Header() for k, v, f in subgroup["keyword", "value", "format"]: - header[k] = self.lexical_cast(v, f) + # ignore warnings about non-standard keywords + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=VerifyWarning) + header[k] = self.lexical_cast(v, f) headers.append(header) return headers @@ -230,3 +293,4 @@ def mock(self, n=1): res.append(self.get(self.counter)) self.counter += 1 return res + diff --git a/src/kbmod/standardizers/butler_standardizer.py b/src/kbmod/standardizers/butler_standardizer.py index ada95074e..f5c48dbc7 100644 --- a/src/kbmod/standardizers/butler_standardizer.py +++ b/src/kbmod/standardizers/butler_standardizer.py @@ -294,7 +294,7 @@ def _fetch_meta(self): # photometric analysis of the results, while the effective # values are too often NaN. The URI location itself is # ultimately not very useful, but helpful for data inspection. - if self.config.standardize_metadata: + if self.config["standardize_metadata"]: meta_ref = self.ref.makeComponentRef("metadata") meta = self.butler.get(meta_ref) @@ -311,13 +311,13 @@ def _fetch_meta(self): self._metadata["GAINB"] = meta["GAINB"] # Will be nan for VR filter so it's optional - if self.config.standardize_effective_summary_stats: + if self.config["standardize_effective_summary_stats"]: self._metadata["effTime"] = summary.effTime self._metadata["effTimePsfSigmaScale"] = summary.effTimePsfSigmaScale self._metadata["effTimeSkyBgScale"] = summary.effTimeSkyBgScale self._metadata["effTimeZeroPointScale"] = summary.effTimeZeroPointScale - if self.config.standardize_uri: + if self.config["standardize_uri"]: self._metadata["location"] = self.butler.getURI( self.ref, collections=[ @@ -348,14 +348,14 @@ def standardizeMetadata(self): def standardizeScienceImage(self): self.exp = self.butler.get(self.ref) if self.exp is None else self.exp - zp_correct = 10 ** ((self._metadata["zeroPoint"] - self.config.zero_point) / 2.5) + zp_correct = 10 ** ((self._metadata["zeroPoint"] - self.config["zero_point"]) / 2.5) return [ self.exp.image.array / zp_correct, ] def standardizeVarianceImage(self): self.exp = self.butler.get(self.ref) if self.exp is None else self.exp - zp_correct = 10 ** ((self._metadata["zeroPoint"] - self.config.zero_point) / 2.5) + zp_correct = 10 ** ((self._metadata["zeroPoint"] - self.config["zero_point"]) / 2.5) return [ self.exp.variance.array / zp_correct**2, ] diff --git a/src/kbmod/standardizers/fits_standardizers/test_data_std.py b/src/kbmod/standardizers/fits_standardizers/test_data_std.py index 89b9ebe9a..d5d9d5e04 100644 --- a/src/kbmod/standardizers/fits_standardizers/test_data_std.py +++ b/src/kbmod/standardizers/fits_standardizers/test_data_std.py @@ -80,7 +80,7 @@ def translateHeader(self): """ # required standardizedHeader = {} - obs_datetime = Time(self.primary["OBS-MJD"], format="mjd") + obs_datetime = Time(self.primary["DATE-OBS"]) standardizedHeader["mjd_mid"] = obs_datetime.mjd # optional standardizedHeader["observat"] = self.primary["OBSERVAT"] diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index eb00f4b91..9f5ea7358 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -10,15 +10,26 @@ # from kbmod.wcs_utils import make_fake_wcs # from kbmod.work_unit import WorkUnit -from utils.utils_for_tests import get_absolute_demo_data_path +#from utils.utils_for_tests import get_absolute_demo_data_path #### import unittest +import itertools +import random + +import numpy as np +from numpy.lib.recfunctions import structured_to_unstructured + +from astropy.time import Time +from astropy.table import Table, vstack +from astropy.wcs import WCS +from astropy.coordinates import SkyCoord from kbmod import ImageCollection from kbmod.run_search import SearchRunner from kbmod.configuration import SearchConfiguration +from kbmod.reprojection import reproject_work_unit import kbmod.mocking as kbmock @@ -63,12 +74,12 @@ def test_static_objects(self): class TestRandomLinearSearch(unittest.TestCase): def setUp(self): # Set up shared search values - self.n_imgs = 10 + self.n_imgs = 5 self.repeat_n_times = 10 - self.shape = (300, 300) - self.start_pos = (125, 175) - self.vxs = [-10, 10] - self.vys = [-10, 10] + self.shape = (200, 200) + self.start_pos = (85, 115) + self.vxs = [-20, 20] + self.vys = [-20, 20] # Set up configs for mocking and search # These don't change from test to test @@ -90,58 +101,208 @@ def setUp(self): "max_vx": self.vxs[1], "min_vy": self.vys[0], "max_vy": self.vys[1], - "vx_steps": 50, - "vy_steps": 50, + "vx_steps": 40, + "vy_steps": 40 }, - "num_obs": 10, + "num_obs": self.n_imgs, "do_mask": False, "do_clustering": True, "do_stamp_filter": False, } ) - def test_simple_search(self): + def xmatch_best(self, obj, results, match_cols={"x_mean": "x", "y_mean": "y", "vx": "vx", "vy": "vy"}): + objk, resk = [], [] + for k, v in match_cols.items(): + if k in obj.columns and v in results.table.columns: + objk.append(k) + resk.append(v) + tgt = np.fromiter(obj[tuple(objk)].values(), dtype=float, count=len(objk)) + res = structured_to_unstructured(results[tuple(resk)].as_array(), dtype=float) + diff = np.linalg.norm(tgt-res, axis=1) + if len(results) == 1: + return results[0], diff + return results[diff == diff.min()][0], diff + + def assertResultValuesWithinSpec(self, expected, result, spec, + match_cols={"x_mean": "x", "y_mean": "y", "vx": "vx", "vy": "vy"}): + for ekey, rkey in match_cols.items(): + info = ( + f"\n Expected: \n {expected[tuple(match_cols.keys())]} \n" + f"Retrieved : \n {result[tuple(match_cols.values())]}" + ) + self.assertLessEqual(abs(expected[ekey] - result[rkey]), spec, info) + + def run_single_search(self, data, expected, spec=5): + ic = ImageCollection.fromTargets(data, force="TestDataStd") + wu = ic.toWorkUnit(search_config=self.config) + results = SearchRunner().run_search_from_work_unit(wu) + + # Run tests + self.assertGreaterEqual(len(results), 1) + for obj in expected.table: + res, dist = (results[0], None) if len(results)==1 else self.xmatch_best(obj, results) + self.assertResultValuesWithinSpec(obj, res, spec) + + def test_exact_motion(self): + search_vs = list(itertools.product([-20, 0, 20], repeat=2)) + search_vs.remove((0, 0)) + for (vx, vy) in search_vs: + with self.subTest(f"Cardinal direction: {(vx, vy)}"): + self.config._params["generator_config"] = { + "name": "SingleVelocitySearch", + "vx": vx, + "vy": vy + } + obj_cat = kbmock.ObjectCatalog.from_defaults(self.param_ranges, n=1) + obj_cat.table["vx"] = vx + obj_cat.table["vy"] = vy + factory = kbmock.SimpleFits(shape=self.shape, step_t=1, obj_cat=obj_cat) + hduls = factory.mock(n=self.n_imgs) + self.run_single_search(hduls, obj_cat, 1) + + def test_random_motion(self): # Mock the data and repeat tests. The random catalog # creation guarantees a diverse set of changing test values for i in range(self.repeat_n_times): - with self.subTest(n=i): + with self.subTest(f"Iteration {i}"): obj_cat = kbmock.ObjectCatalog.from_defaults(self.param_ranges, n=1) - factory = kbmock.SimpleFits(shape=self.shape, step_mjd=1, obj_cat=obj_cat) + factory = kbmock.SimpleFits(shape=self.shape, step_t=1, obj_cat=obj_cat) hduls = factory.mock(n=self.n_imgs) + self.run_single_search(hduls, obj_cat) + + def test_reprojected_search(self): + # 0. Setup + self.shape = (500, 500) + self.start_pos = (10, 10) # (ra, dec) in deg + n_obj = 1 + pixscale = 0.2 + timestamps = Time(np.arange(58915, 58915+self.n_imgs, 1), format="mjd") + vx = 0.001 # degrees / day (given the timestamps) + vy = 0.001 + + # 1. Mock data + # - mock catalogs, set expected positions by hand + # - mock WCSs so that they dither around (10, 10) + # - instantiate the required mockers and mock + cats = [] + for i, t in enumerate(timestamps): + cats.append( + Table({ + "amplitude": [100], + "obstime": [t], + "ra_mean": [self.start_pos[0] + vx*i], + "dec_mean": [self.start_pos[1] + vy*i], + "stddev": [2.0] + })) + catalog = vstack(cats) + obj_cat = kbmock.ObjectCatalog.from_table(catalog, kind="world", mode="folding") + + wcs_factory = kbmock.WCSFactory( + pointing=self.start_pos, + rotation=0, + pixscale=pixscale, + dither_pos=True, + dither_rot=True, + dither_amplitudes=(0.001, 0.001, 0.01) + ) - ic = ImageCollection.fromTargets(hduls, force="TestDataStd") - wu = ic.toWorkUnit(search_config=self.config) - results = SearchRunner().run_search_from_work_unit(wu) - - # Run tests - self.assertGreaterEqual(len(results), 1) - for res in results: - diff = abs(obj_cat.table["y_mean"] - res["y"]) - obj = obj_cat.table[diff == diff.min()] - self.assertLessEqual(abs(obj["x_mean"] - res["x"]), 5) - self.assertLessEqual(abs(obj["y_mean"] - res["y"]), 5) - self.assertLessEqual(abs(obj["vx"] - res["vx"]), 5) - self.assertLessEqual(abs(obj["vy"] - res["vy"]), 5) - - def test_diffim_mocks(self): - src_cat = kbmock.SourceCatalog.from_defaults() - obj_cat = kbmock.ObjectCatalog.from_defaults(self.param_ranges, n=1) - factory = kbmock.DECamImdiff.from_defaults(with_data=True, src_cat=src_cat, obj_cat=obj_cat) + prim_hdr_factory = kbmock.HeaderFactory.from_primary_template( + mutables=["DATE-OBS"], + callbacks=[kbmock.ObstimeIterator(timestamps)], + ) + + factory = kbmock.SimpleFits( + shape=self.shape, + obj_cat=obj_cat, + wcs_factory=wcs_factory + ) + factory.prim_hdr = prim_hdr_factory hduls = factory.mock(n=self.n_imgs) + # 2. Run search + # - make an IC + # - determine WCS footprint to reproject to + # - determine the pixel-based velocity to search for + # - reproject + # - run search ic = ImageCollection.fromTargets(hduls, force="TestDataStd") - wu = ic.toWorkUnit(search_config=self.config) - results = SearchRunner().run_search_from_work_unit(wu) - # Run tests + from reproject.mosaicking import find_optimal_celestial_wcs + opt_wcs, self.shape = find_optimal_celestial_wcs(list(ic.wcs)) + opt_wcs.array_shape = self.shape + + meanvx = -vx * 3600 / pixscale + meanvy = vy * 3600 / pixscale + + # The velocity grid needs to be searched very densely for the realistic + # case (compared to the fact the velocity spread is not that large), and + # we'll still end up ~10 pixels away from the truth. + search_config = SearchConfiguration.from_dict({ + "generator_config": { + "name": "VelocityGridSearch", + "min_vx": meanvx-5, + "max_vx": meanvx+5, + "min_vy": meanvy-5, + "max_vy": meanvy+5, + "vx_steps": 40, + "vy_steps": 40 + }, + "num_obs": 1, + "do_mask": False, + "do_clustering": True, + "do_stamp_filter": False, + }) + wu = ic.toWorkUnit(search_config) + repr_wu = reproject_work_unit(wu, opt_wcs, parallelize=False) + results = SearchRunner().run_search_from_work_unit(repr_wu) + + # Compare results and validate + # - add in pixel velocities because realistic searches rarely + # find good pixel location match + # - due to that, we also can't rely that we'll get a good match on + # any particular catalog realization. We iterate over all of them + # and find the best matching results in each realization. + # From all realizations find the one that matches the best. + # Select that realization and that best match for comparison. + cats = obj_cat.mock(t=timestamps, wcs=[opt_wcs]*self.n_imgs) + for cat in cats: + cat["vx"] = meanvx + cat["vy"] = meanvy + + dists = np.array([self.xmatch_best(cat, results)[1] for cat in cats]) + min_dist_within_realization = dists.min(axis=0) + min_dist_across_realizations = dists.min() + + best_realization = dists.min(axis=1) == min_dist_across_realizations + best_realization_idx = np.where(best_realization == True)[0][0] + + best_cat = cats[best_realization_idx] + best_res = results[dists[best_realization_idx] == min_dist_across_realizations] + self.assertGreaterEqual(len(results), 1) - for res in results: - diff = abs(obj_cat.table["y_mean"] - res["y"]) - obj = obj_cat.table[diff == diff.min()] - self.assertLessEqual(abs(obj["x_mean"] - res["x"]), 5) - self.assertLessEqual(abs(obj["y_mean"] - res["y"]), 5) - self.assertLessEqual(abs(obj["vx"] - res["vx"]), 5) - self.assertLessEqual(abs(obj["vy"] - res["vy"]), 5) + self.assertResultValuesWithinSpec(best_cat, best_res, 10) + +# def test_diffim_mocks(self): +# src_cat = kbmock.SourceCatalog.from_defaults() +# obj_cat = kbmock.ObjectCatalog.from_defaults(self.param_ranges, n=1) +# factory = kbmock.DECamImdiff.from_defaults(with_data=True, src_cat=src_cat, obj_cat=obj_cat) +# hduls = factory.mock(n=self.n_imgs) +# +# +# ic = ImageCollection.fromTargets(hduls, force="TestDataStd") +# wu = ic.toWorkUnit(search_config=self.config) +# results = SearchRunner().run_search_from_work_unit(wu) +# +# # Run tests +# self.assertGreaterEqual(len(results), 1) +# for res in results: +# diff = abs(obj_cat.table["y_mean"] - res["y"]) +# obj = obj_cat.table[diff == diff.min()] +# self.assertLessEqual(abs(obj["x_mean"] - res["x"]), 5) +# self.assertLessEqual(abs(obj["y_mean"] - res["y"]), 5) +# self.assertLessEqual(abs(obj["vx"] - res["vx"]), 5) +# self.assertLessEqual(abs(obj["vy"] - res["vy"]), 5) #### diff --git a/tests/test_mocking.py b/tests/test_mocking.py index 8c159335b..b03ce580e 100644 --- a/tests/test_mocking.py +++ b/tests/test_mocking.py @@ -1,6 +1,10 @@ import unittest import numpy as np + +import astropy.units as u +from astropy.wcs import WCS +from astropy.time import Time from astropy.table import Table, vstack import kbmod.mocking as kbmock @@ -23,14 +27,14 @@ def test(self): self.assertTrue((hdul["VARIANCE"].data == zeros).all()) self.assertTrue((hdul["MASK"].data == zeros).all()) - factory = kbmock.EmptyFits(shape=(10, 100), step_mjd=1) + factory = kbmock.EmptyFits(shape=(10, 100), step_t=1) hduls = factory.mock(2) hdul = hduls[0] self.assertEqual(hdul["IMAGE"].data.shape, (10, 100)) self.assertEqual(hdul["VARIANCE"].data.shape, (10, 100)) self.assertEqual(hdul["MASK"].data.shape, (10, 100)) - dt = hduls[1]["PRIMARY"].header["OBS-MJD"] - hduls[0]["PRIMARY"].header["OBS-MJD"] - self.assertEqual(dt, 1) + dt = Time(hduls[1]["PRIMARY"].header["DATE-OBS"]) - Time(hduls[0]["PRIMARY"].header["DATE-OBS"]) + self.assertEqual(dt.to("day").value, 1) with self.assertRaisesRegex(ValueError, "destination is read-only"): hdul["IMAGE"].data[0, 0] = 0 @@ -54,6 +58,14 @@ def test(self): class TestSimpleFits(unittest.TestCase): + def setUp(self): + self.n_obj = 5 + self.n_imgs = 3 + self.shape = (100, 300) + self.padded = ((10, 90), (10, 290)) + self.timestamps = Time(np.arange(58915, 58915+self.n_imgs, 1), format="mjd") + self.step_t = 1 + def test(self): """Test basic functionality of SimpleFits factory.""" factory = kbmock.SimpleFits() @@ -70,131 +82,175 @@ def test(self): self.assertTrue((hdul["VARIANCE"].data == zeros).all()) self.assertTrue((hdul["MASK"].data == zeros).all()) - factory = kbmock.SimpleFits(shape=(10, 100), step_mjd=1) + factory = kbmock.SimpleFits(shape=(10, 100), step_t=1) hduls = factory.mock(2) hdul = hduls[0] self.assertEqual(hdul["IMAGE"].data.shape, (10, 100)) self.assertEqual(hdul["VARIANCE"].data.shape, (10, 100)) self.assertEqual(hdul["MASK"].data.shape, (10, 100)) - step_mjd = hduls[1]["PRIMARY"].header["OBS-MJD"] - hduls[0]["PRIMARY"].header["OBS-MJD"] - self.assertEqual(step_mjd, 1) + step_t = Time(hduls[1]["PRIMARY"].header["DATE-OBS"]) - Time(hduls[0]["PRIMARY"].header["DATE-OBS"]) + self.assertEqual(step_t.to("day").value, 1.0) def test_static_src_cat(self): src_cat = kbmock.SourceCatalog.from_defaults() src_cat2 = kbmock.SourceCatalog.from_defaults() - self.assertEqual(src_cat.config.mode, "static") + self.assertEqual(src_cat.config["mode"], "static") self.assertFalse((src_cat.table == src_cat2.table).all()) - src_cat = kbmock.SourceCatalog.from_defaults(n=3) - self.assertEqual(len(src_cat.table), 3) + src_cat = kbmock.SourceCatalog.from_defaults(n=self.n_obj) + self.assertEqual(len(src_cat.table), self.n_obj) - shape = (300, 500) param_ranges = { "amplitude": [100, 100], - "x_mean": (100, 200), - "y_mean": (50, 80), + "x_mean": self.padded[1], + "y_mean": self.padded[0], "x_stddev": [2.0, 2.0], "y_stddev": [2.0, 2.0], } src_cat = kbmock.SourceCatalog.from_defaults(param_ranges, seed=100) src_cat2 = kbmock.SourceCatalog.from_defaults(param_ranges, seed=100) self.assertTrue((src_cat.table == src_cat2.table).all()) - self.assertLessEqual(src_cat.table["x_mean"].max(), shape[1]) - self.assertLessEqual(src_cat.table["y_mean"].max(), shape[0]) + self.assertLessEqual(src_cat.table["x_mean"].max(), self.shape[1]) + self.assertLessEqual(src_cat.table["y_mean"].max(), self.shape[0]) - factory = kbmock.SimpleFits(shape=shape, src_cat=src_cat) + factory = kbmock.SimpleFits(shape=self.shape, src_cat=src_cat) hdul = factory.mock()[0] - for x, y in src_cat.table["x_mean", "y_mean"]: - # Can only test greater or equal because objects may overlap - self.assertGreaterEqual(hdul["IMAGE"].data[int(y), int(x)], 80) + x = np.round(src_cat.table["x_mean"].data).astype(int) + y = np.round(src_cat.table["y_mean"].data).astype(int) + self.assertGreaterEqual(hdul["IMAGE"].data[y, x].min(), 90) + + def validate_cat_render(self, hduls, cats, expected_gte=90): + for hdul, cat in zip(hduls, cats): + x = np.round(cat["x_mean"].data).astype(int) + y = np.round(cat["y_mean"].data).astype(int) + self.assertGreaterEqual(hdul["IMAGE"].data[y, x].min(), expected_gte) + self.assertGreaterEqual(hdul["VARIANCE"].data[y, x].min(), expected_gte) def test_progressive_obj_cat(self): obj_cat = kbmock.ObjectCatalog.from_defaults() obj_cat2 = kbmock.ObjectCatalog.from_defaults() - self.assertEqual(obj_cat.config.mode, "progressive") + self.assertEqual(obj_cat.config["mode"], "progressive") self.assertFalse((obj_cat.table == obj_cat2.table).all()) - obj_cat = kbmock.ObjectCatalog.from_defaults(n=3) - self.assertEqual(len(obj_cat.table), 3) + obj_cat = kbmock.ObjectCatalog.from_defaults(n=self.n_obj) + self.assertEqual(len(obj_cat.table), self.n_obj) - shape = (300, 500) param_ranges = { "amplitude": [100, 100], - "x_mean": (0, 50), - "y_mean": (50, shape[0]), + "x_mean": (0, 90), + "y_mean": self.padded[0], "x_stddev": [2.0, 2.0], "y_stddev": [2.0, 2.0], - "vx": [100, 300], + "vx": [10, 20], "vy": [0, 0], } - obj_cat = kbmock.ObjectCatalog.from_defaults(param_ranges, seed=100) - obj_cat2 = kbmock.ObjectCatalog.from_defaults(param_ranges, seed=100) + seed = 200 + obj_cat = kbmock.ObjectCatalog.from_defaults(param_ranges, seed=seed) + obj_cat2 = kbmock.ObjectCatalog.from_defaults(param_ranges, seed=seed) self.assertTrue((obj_cat.table == obj_cat2.table).all()) - self.assertLessEqual(obj_cat.table["x_mean"].max(), 50) - self.assertLessEqual(obj_cat.table["y_mean"].max(), shape[0]) - step_mjd = 0.1 - obj_cat = kbmock.ObjectCatalog.from_defaults(param_ranges, n=5, seed=100) - factory = kbmock.SimpleFits(shape=shape, step_mjd=step_mjd, obj_cat=obj_cat) - hduls = factory.mock(n=5) + obj_cat = kbmock.ObjectCatalog.from_defaults(param_ranges, n=self.n_obj) + factory = kbmock.SimpleFits(shape=self.shape, step_t=self.step_t, obj_cat=obj_cat) + hduls = factory.mock(n=self.n_imgs) obj_cat.reset() - for i in range(5): - newcat = obj_cat.mock(dt=step_mjd)[0] - for x, y in newcat["x_mean", "y_mean"]: - if x < shape[1] and y < shape[0]: - # Can only test greater or equal because objects may overlap - # probably a rounding-off error while moving drops 1 flux count - self.assertGreaterEqual(hduls[i]["IMAGE"].data[int(y), int(x)], 79) - self.assertGreaterEqual(hduls[i]["VARIANCE"].data[int(y), int(x)], 0) + cats = obj_cat.mock(n=self.n_imgs, dt=self.step_t) + self.validate_cat_render(hduls, cats) def test_folding_obj_cat(self): - nobj = 5 - shape = (300, 300) - timestamps = np.arange(58915, 58920, 1) - - start_x = np.ones((nobj,)) * 10 - start_y = np.linspace(10, shape[0] - 10, nobj) + # Set up shared values for the whole setup + # like starting positions of object and timestamps + start_x = np.ones((self.n_obj,)) * 10 + start_y = np.linspace(10, self.shape[0] - 10, self.n_obj) + # Set up non-linear catalog (objects will move as counter^2*v) cats = [] - for i, t in enumerate(timestamps): + for i, t in enumerate(self.timestamps): cats.append( - Table( - { - "amplitude": [100] * nobj, - "obstime": [t] * nobj, - "x_mean": start_x + 15 * i * i, - "y_mean": start_y, - "stddev": [2.0] * nobj, - } - ) - ) + Table({ + "amplitude": [100] * self.n_obj, + "obstime": [t] * self.n_obj, + "x_mean": start_x + 15 * i * i, + "y_mean": start_y, + "stddev": [2.0] * self.n_obj, + })) catalog = vstack(cats) - obj_cat = kbmock.ObjectCatalog.from_table(catalog) - obj_cat.mode = "folding" + # Mock data based on that catalog + obj_cat = kbmock.ObjectCatalog.from_table(catalog, mode="folding") prim_hdr_factory = kbmock.HeaderFactory.from_primary_template( - mutables=["OBS-MJD"], - callbacks=[ - kbmock.ObstimeIterator(timestamps), - ], + mutables=["DATE-OBS"], + callbacks=[kbmock.ObstimeIterator(self.timestamps)], ) - factory = kbmock.SimpleFits(shape=shape, obj_cat=obj_cat) + factory = kbmock.SimpleFits(shape=self.shape, obj_cat=obj_cat) factory.prim_hdr = prim_hdr_factory - hduls = factory.mock(n=len(timestamps)) + hduls = factory.mock(n=self.n_imgs) + # Run tests and ensure we have rendered the object in correct + # positions obj_cat.reset() - cats = obj_cat.mock(t=timestamps) - for hdul, cat in zip(hduls, cats): - for x, y in cat["x_mean", "y_mean"]: - if x < shape[1] and y < shape[0]: - # Can only test greater or equal because objects may overlap - # probably a rounding-off error while moving drops 1 flux count - self.assertGreaterEqual(hdul["IMAGE"].data[int(y), int(x)], 79) - self.assertGreaterEqual(hdul["VARIANCE"].data[int(y), int(x)], 0) + cats = obj_cat.mock(n=self.n_imgs, t=self.timestamps) + self.validate_cat_render(hduls, cats) + + def test_progressive_sky_cat(self): + # a 10-50 in x by a 10-90 in y box using default WCS + #self.shape = (500, 500) + param_ranges = { + "ra_mean": (350.998, 351.002), + "dec_mean": (-5.0077, -5.0039), + "v_ra": [-0.001, 0.0001], + "v_dec": [0, 0], + "amplitude": [100, 100], + "x_stddev": [2.0, 2.0], + "y_stddev": [2.0, 2.0], + } + catalog = kbmock.gen_random_catalog(self.n_obj, param_ranges) + obj_cat = kbmock.ObjectCatalog.from_table(catalog, kind="world") + + factory = kbmock.SimpleFits(shape=self.shape, step_t=self.step_t, obj_cat=obj_cat) + hduls = factory.mock(n=self.n_imgs) + + # Run tests and ensure we have rendered the object in correct + # positions + obj_cat.reset() + wcs = [WCS(h["IMAGE"].header) for h in hduls] + cats = obj_cat.mock(n=self.n_imgs, dt=self.step_t, wcs=wcs) + self.validate_cat_render(hduls, cats) + + def test_folding_sky_cat(self): + # a 20x20 box in pixels using a default WCS + start_ra = np.linspace(350.998, 351.002, self.n_obj) + start_dec = np.linspace(-5.0077, -5.0039, self.n_obj) + + cats = [] + for i, t in enumerate(self.timestamps): + cats.append( + Table({ + "amplitude": [100] * self.n_obj, + "obstime": [t] * self.n_obj, + "ra_mean": start_ra - 0.001*i, + "dec_mean": start_dec,# + 0.00011 * i, + "stddev": [2.0] * self.n_obj + })) + catalog = vstack(cats) + obj_cat = kbmock.ObjectCatalog.from_table(catalog, kind="world", mode="folding") + + prim_hdr_factory = kbmock.HeaderFactory.from_primary_template( + mutables=["DATE-OBS"], + callbacks=[kbmock.ObstimeIterator(self.timestamps)], + ) + + factory = kbmock.SimpleFits(shape=self.shape, obj_cat=obj_cat) + factory.prim_hdr = prim_hdr_factory + hduls = factory.mock(n=self.n_imgs) + + obj_cat.reset() + wcs = [WCS(h[1].header) for h in hduls] + cats = obj_cat.mock(n=self.n_imgs, t=self.timestamps, wcs=wcs) + self.validate_cat_render(hduls, cats) # TODO: move to pytest and mark as xfail def test_noise_gen(self): @@ -219,5 +275,43 @@ def test_noise_gen(self): self.assertAlmostEqual(hdul["IMAGE"].data.std(), 2, 1) +class TestDiffIm(unittest.TestCase): + def test(self): + """Test basic functionality of SimpleFits factory.""" + factory = kbmock.DECamImdiff() + hduls = factory.mock(2) + + names = [ + "IMAGE", + "MASK", + "VARIANCE", + "ARCHIVE_INDEX", + "FilterLabel", + "Detector", + "TransformMap", + "ExposureSummaryStats", + "Detector", + "KernelPsf", + "FixedKernel", + "SkyWcs", + "ApCorrMap", + "ChebyshevBoundedField", + "ChebyshevBoundedField" + ] + hdul = hduls[0] + self.assertEqual(len(hduls), 2) + self.assertEqual(len(hduls[0]), 16) + for name, hdu in zip(names, hdul[1:]): + self.assertEqual(name, hdu.name) + self.assertEqual(hdul["PRIMARY"].data, None) + + factory = kbmock.DECamImdiff(with_data=True) + hduls = factory.mock(2) + hdul = hduls[0] + self.assertEqual(hdul["IMAGE"].data.shape, (2048, 4096)) + self.assertEqual(hdul["VARIANCE"].data.shape, (2048, 4096)) + self.assertEqual(hdul["MASK"].data.shape, (2048, 4096)) + + if __name__ == "__main__": unittest.main()