From f25f0dea6f4602b383e38b6540085fb36f393731 Mon Sep 17 00:00:00 2001 From: DinoBektesevic Date: Tue, 9 Apr 2024 13:39:00 -0700 Subject: [PATCH] Refactor data factories to use the config class. --- src/kbmod/mocking/fits_data.py | 361 ++++++++++++++++++++++----------- 1 file changed, 247 insertions(+), 114 deletions(-) diff --git a/src/kbmod/mocking/fits_data.py b/src/kbmod/mocking/fits_data.py index 5519ed93..fde6fa68 100644 --- a/src/kbmod/mocking/fits_data.py +++ b/src/kbmod/mocking/fits_data.py @@ -1,10 +1,23 @@ import numpy as np -from astropy.io.fits import PrimaryHDU, CompImageHDU, ImageHDU, BinTableHDU, TableHDU +from astropy.io.fits import ( + PrimaryHDU, + CompImageHDU, + ImageHDU, + BinTableHDU, + TableHDU +) from astropy.modeling import models -from astropy.convolution import discretize_model +from .config import Config, ConfigurationError -__all__ = ["add_model_objects", "DataFactory", "ZeroedData", "SimpleImage", "SimpleMask"] + +__all__ = [ + "add_model_objects", + "DataFactory", + "ZeroedData", + "SimpleImage", + "SimpleMask" +] def add_model_objects(img, catalog, model): @@ -40,53 +53,89 @@ def add_model_objects(img, catalog, model): return img + + +class DataFactoryConfig(Config): + writeable = False + """Sets the base array writeable flag.""" + + copy_base = False + """ + When `True`, a copy of the base data object is passed into the generate + method, otherwise, the original (possibly mutable!) base data object is + given. + """ + + copy_mocked = False + """ + When `True`, the `DataFactory.mock` returns a copy of the final object, + otherwise the original (possibly mutable!) object is returned. + """ + + return_copy = False + + + isStatic = False + """ + When `False` the `DataFactory.mock` will generate new data every time + it is called. Otherwise it will memoize the result of the first time + it's called and return that result directly or as a copy. + """ + + class DataFactory: - # https://archive.stsci.edu/fits/fits_standard/node39.html#s:man - bitpix_type_map = { - # or char - 8: int, - # actually no idea what dtype, or C type for that matter, - # are used to represent these values. But default Headers return them - 16: np.float16, - 32: np.float32, - 64: np.float64, - # classic IEEE float and double - -32: np.float32, - -64: np.float64, - } + default_config = DataFactoryConfig - def __init__(self, base=None, return_copy=False, mutable=False, **kwargs): + def __init__(self, base=None, config=None, **kwargs): # not sure if this is "best" "best" way, but it does safe a lot of # array copies if we don't have to write to the mocked array # (which we shouldn't?). To be safe we set the writable flag to False # by default - self.return_copy = return_copy - self.mutable = mutable + self.config = self.default_config() + self.config.update(config, **kwargs) + self.base = base if base is not None: - self.base.flags.writeable = mutable + self.base.flags.writeable = self.config.writeable + self.counter = 0 def mock(self, hdu=None, **kwargs): - if self.return_copy: + self.counter += 1 + if self.config.return_copy: return self.base.copy() return self.base + +class SimpleVarianceConfig(DataFactoryConfig): + read_noise = 0.0 + gain = 1.0 + calculate_base = True + class SimpleVariance(DataFactory): - def __init__(self, image, read_noise, gain, return_copy=False, mutable=False): - self.read_noise = read_noise - self.gain = gain - super().__init__(base=image / gain + read_noise**2, return_copy=return_copy, mutable=mutable) + default_config = SimpleVarianceConfig + + def __init__(self, image=None, config=None, **kwargs): + # skip setting the base here since the real base is + # derived from given image we just set it manually + super().__init__(base=None, config=config, **kwargs) + + if image is not None: + self.base = image/self.config.gain + self.config.read_noise**2 def mock(self, images=None): if images is None: return self.base - return images/self.gain + self.read_noise**2 + return images/self.config.gain + self.config.read_noise**2 +class SimpleMaskConfig(DataFactoryConfig): + pass + class SimpleMask(DataFactory): - def __init__(self, mask, return_copy=False, mutable=False): - super().__init__(mask, return_copy, mutable) + default_config = SimpleMaskConfig + def __init__(self, mask, config=None, **kwargs): + super().__init__(base=mask, config=config, **kwargs) @classmethod def from_params(cls, shape, padding=0, bad_columns=[]): @@ -119,18 +168,40 @@ def from_patches(cls, shape, patches): return cls(mask) +class ZeroedDataConfig(DataFactoryConfig): + shape = (5, 5) + """Default image size.""" + + # https://archive.stsci.edu/fits/fits_standard/node39.html#s:man + bitpix_type_map = { + # or char + 8: int, + # actually no idea what dtype, or C type for that matter, + # are used to represent these values. But default Headers return them + 16: np.float16, + 32: np.float32, + 64: np.float64, + # classic IEEE float and double + -32: np.float32, + -64: np.float64, + } + """Map between FITS header BITPIX keyword value and NumPy return type.""" + class ZeroedData(DataFactory): - def __init__(self, base=None, return_copy=False, mutable=False): - super().__init__(base, return_copy, mutable) + default_config = ZeroedDataConfig + + def __init__(self, base=None, config=None, **kwargs): + super().__init__(base, config, **kwargs) def mock_image_data(self, hdu): cols = hdu.header.get("NAXIS1", False) rows = hdu.header.get("NAXIS2", False) + shape = (cols, rows) if all((cols, rows)) else self.config.shape - cols = 5 if not cols else cols - rows = 5 if not rows else rows - - data = np.zeros((cols, rows), dtype=self.bitpix_type_map[hdu.header["BITPIX"]]) + data = np.zeros( + shape, + dtype=self.config.bitpix_type_map[hdu.header["BITPIX"]] + ) return data def mock_table_data(self, hdu): @@ -154,137 +225,199 @@ def mock(self, hdu=None, **kwargs): raise TypeError(f"Expected an HDU, got {type(hdu)} instead.") -class SimpleImage(DataFactory): - rng = np.random.default_rng() - noise_gen = rng.standard_normal + +class SimpleImageConfig(DataFactoryConfig): + shape = (100, 100) + seed = None + noise = 0 + noise_std = 1.0 model = models.Gaussian2D - def __init__(self, image=None, shape=(1000, 1000), noise=0, noise_std=1.0, src_cat=None, **kwargs): + +class SimpleImage(DataFactory): + default_config = SimpleImageConfig + + def __init__(self, image=None, config=None, src_cat=None, **kwargs): + super().__init__(image, config, **kwargs) + if image is None: - image = np.zeros(shape, dtype=np.float32) - self.shape = shape + image = np.zeros(self.config.shape, dtype=np.float32) else: image = image - self.shape = image.shape + self.config.shape = image.shape if src_cat is not None: - add_model_objects(image, src_cat.table, self.model) - - super().__init__(image, False, False) + add_model_objects(image, src_cat.table, self.config.model) + self.base = image self._base_contains_data = image.sum() != 0 - self.noise = noise - self.noise_std = noise_std + + @classmethod + def add_noise(cls, n, images, config): + rng = np.random.default_rng(seed=config.seed) + shape = images.shape + + # noise has to be resampled for every image + rng.standard_normal(size=shape, dtype=np.float32, out=images) + + # There's a lot of multiplications that happen, skip if possible + if self.config.noise_std != 1.0: + images *= config.noise_std + images += config.noise + + return images def mock(self, n=1, obj_cats=None, **kwargs): - shape = (n, *self.shape) + shape = (n, *self.config.shape) images = np.zeros(shape, dtype=np.float32) - if self.noise != 0: - rng = np.random.default_rng() - rng.standard_normal(size=shape, dtype=np.float32, out=images) - if self.noise_std != 1.0: - images *= self.noise_std - images += self.noise + if self.config.noise != 0: + images = self.gen_noise(n, images, self.config) + # but if base has no data (no sources, bad cols etc) skip if self._base_contains_data: images += self.base + # same with moving objects if obj_cats is not None: for i, (img, cat) in enumerate(zip(images, obj_cats)): - add_model_objects(img, cat, self.model(x_stddev=1, y_stddev=1)) + add_model_objects( + img, + cat, + self.config.model(x_stddev=1, y_stddev=1) + ) return images -class SimulatedImage(DataFactory): + + +class SimulatedImageConfig(DataFactoryConfig): + # not sure this is a smart idea to put here rng = np.random.default_rng() + + # image properties + shape = (100, 100) + + # detector properties read_noise_gen = rng.normal + read_noise = 0 + gain = 1.0 + bias = 0.0 + + add_bad_cols = False + bad_cols_method = "random" + bad_col_locs = [] # for manual setting of cols + n_bad_cols = 5 + bad_cols_seed = 123 + bad_col_pattern_offset = 0.1 + dark_current_gen = rng.poisson + dark_current = 0 + + add_hot_pixels = False + hot_pix_method = "random" + hot_pix_locs = [] + hot_pix_seed = 321 + n_hot_pix = 10 + hot_pix_offset = 1000 + + # Observation properties + exposure_time = 120.0 sky_count_gen = rng.poisson + sky_level = 0 + + # Object and Source properties + model = models.Gaussian2D - def __add_bad_cols(self, image, cols, bias, n, seed, pattern_offset): - if not cols: - # most of the time I imagine we don't need bad cols + +class SimulatedImage(SimpleImage): + default_config = SimulatedImageConfig + + @classmethod + def add_bad_cols(cls, image, config): + if not config.add_bad_cols: return image - if bad_columns == "random": + if config.bad_cols_method == "random": rng = np.random.RandomState(seed=self.bad_cols_seed) - bad_cols = rng.randint(0, shape[1], size=n) + bad_cols = rng.randint(0, shape[1], size=config.n_bad_cols) + elif config.bad_col_locs: + bad_cols = config.bad_col_locs else: - bad_cols = bad_columns + raise ConfigurationError("Bad columns method is not 'random', but `bad_col_locs` contains no bad column indices.") + + self.col_pattern = rng.randint( + low=0, + high=int(config.bad_col_pattern_offset * config.bias), + size=shape[0] + ) - self.col_pattern = rng.randint(0, int(pattern_offset * bias), size=shape[0]) for col in columns: - image[:, col] = bias + col_pattern + image[:, col] = config.bias + col_pattern return image - def __add_hot_pixels(self, image, pixels, percent, offset): - if not pixels: - # most of the time I imagine we don't need hot pixels + @classmethod + def add_hot_pixels(cls, image, config): + if not config.add_hot_pixels: return image - if pixels == "random": - rng = np.random.RandomState(seed=self.hot_pixel_seed) - n_pixels = image.shape[0] * image.shape[1] - n = int(percent * n_pixels) - x = rng.randint(0, shape[1], size=n) - y = rng.randint(0, shape[0], size=n) + if config.hot_pix_method == "random": + rng = np.random.RandomState(seed=config.hot_pix_seed) + x = rng.randint(0, shape[1], size=config.n_hot_pix) + y = rng.randint(0, shape[0], size=config.n_hot_pix) hot_pixels = np.column_stack(x, y) - else: + elif config.hot_pix_locs: hot_pixels = pixels + else: + raise ConfigurationError("Hot pixels method is not 'random', but `hot_pix_locs` contains no (col, row) location indices of hot pixels.") for pix in hot_pixels: image[*pix] += image[*pix] * offset return image - def __init__( - self, - shape, - read_noise, - gain, - bias, - exposure_time, - dark_current, - sky_level, - bad_columns=False, - bad_cols_seed=321, - n_bad_cols=5, - bad_col_pattern_offset=0.1, - hot_pixels=False, - hot_pix_seed=543, - hot_percent=0.00001, - hot_offset=1000, - ): - super().__init__() - - # starting empty image - self.base_img = np.zeros(shape) + @classmethod + def add_noise(cls, n, images, config): + shape = images.shape # add read noise - self.base_img += self.read_noise_gen(scale=read_noise / gain, size=shape) - - # add bias - self.base_img += bias - - # add bad columns - self.bad_cols_seed = bad_cols_seed - self.__add_bad_cols(self.base_image, bad_columns, bias, n_bad_cols, bad_cols_seed) + images += self.read_noise_gen( + scale=config.read_noise / config.gain, + size=shape + ) # add dark current - current = dark_current * exposure_time / gain - self.base_img = self.dark_current_gen(current, size=shape) - - # add hot pixels - self.hot_pixel_seed = hot_pix_seed - self.base_img = self.__add_hot_pixels(self.base_img, hot_pixels, hot_percent, hot_offset) + current = config.dark_current * config.exposure_time / config.gain + images += config.dark_current_gen(current, size=shape) # add sky counts - self.base_img += self.sky_count_gen(sky_level * gain, shape) / gain + images += self.sky_count_gen( + lam=config.sky_level * config.gain, + size=shape + ) / config.gain - def mock(self, hdu=None, **kwargs): - # we do always have to return a new copy here, since sci images - # are expected to be written on - return self.base_img.copy() + return images + + @classmethod + def gen_base_image(cls, config=None): + config = cls.default_config(config) + + # empty image + base = np.zeros(config.shape, dtype=np.float32) + base += config.bias + base = cls.add_bad_cols(base, config) + base = cls.add_hot_pixels(base, config) + + return base + + def __init__(self, image=None, config=None, src_cat=None): + conf = self.default_config(config) + base = self.gen_base_image(conf) + super().__init__(base, conf, src_cat) + +# def mock(self, n=1, obj_cats=None, **kwargs): +# # we do always have to return a new copy here, since sci images +# # are expected to be written on +# return self.base_img.copy()