Skip to content

Commit

Permalink
Fixup the mocking code.
Browse files Browse the repository at this point in the history
  • Loading branch information
DinoBektesevic committed Apr 23, 2024
1 parent 2d99408 commit d797271
Show file tree
Hide file tree
Showing 6 changed files with 906 additions and 402 deletions.
1 change: 0 additions & 1 deletion src/kbmod/mocking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@
from .headers import *
from .fits_data import *
from .fits import *
#from . import test_mocking
184 changes: 132 additions & 52 deletions src/kbmod/mocking/catalogs.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import abc

import numpy as np
from astropy.time import Time
from astropy.table import QTable, vstack
from astropy.table import QTable
from .config import Config


__all__ = [
"gen_catalog",
"CatalogFactory",
"SimpleSourceCatalog",
"SimpleObjectCatalog",
"SimpleCatalog",
"SourceCatalogConfig",
"SourceCatalog",
"ObjectCatalogConfig",
"ObjectCatalog",
]


Expand All @@ -26,84 +29,161 @@ def gen_catalog(n, param_ranges, seed=None):

# conversion assumes a gaussian
if "flux" in param_ranges and "amplitude" not in param_ranges:
xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1
ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1
xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1.0
ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1.0

cat["amplitude"] = cat["flux"] / (2.0 * np.pi * xstd * ystd)

return cat



class CatalogFactory(abc.ABC):
@abc.abstractmethod
def gen_realization(self, *args, t=None, dt=None, **kwargs):
def mock(self, *args, **kwargs):
raise NotImplementedError()

def mock(self, *args, **kwargs):
return self.gen_realization(self, *args, **kwargs)

class SimpleCatalogConfig(Config):
return_copy = False
seed = None
n = 100
param_ranges = {}


class SimpleCatalog(CatalogFactory):
default_config = SimpleCatalogConfig

def __init_from_table(self, table, config=None, **kwargs):
config = self.default_config(config=config, **kwargs)
config.n = len(table)
params = {}
for col in table.keys():
params[col] = (table[col].min(), table[col].max())
config.param_ranges.update(params)
return config, table

def __init_from_config(self, config, **kwargs):
config = self.default_config(config=config, method="subset", **kwargs)
table = gen_catalog(config.n, config.param_ranges, config.seed)
return config, table

def __init_from_ranges(self, **kwargs):
param_ranges = kwargs.pop("param_ranges", None)
if param_ranges is None:
param_ranges = {k: v for k, v in kwargs.items() if k in self.default_config.param_ranges}
kwargs = {k: v for k, v in kwargs.items() if k not in self.default_config.param_ranges}

config = self.default_config(**kwargs, method="subset")
config.param_ranges.update(param_ranges)
return self.__init_from_config(config=config)

def __init__(self, table=None, config=None, **kwargs):
if table is not None:
config, table = self.__init_from_table(table, config=config, **kwargs)
elif isinstance(config, Config):
config, table = self.__init_from_config(config=config, **kwargs)
elif isinstance(config, dict) or kwargs:
config = {} if config is None else config
config, table = self.__init_from_ranges(**{**config, **kwargs})
else:
raise ValueError(
"Expected table or config, or keyword arguments of expected "
f"catalog value ranges, got:\n table={table}\n config={config} "
f"\n kwargs={kwargs}"
)

self.config = config
self.table = table
self.current = 0

class SimpleSourceCatalog(CatalogFactory):
base_param_ranges = {
"amplitude": [500, 2000],
"x_mean": [0, 4096],
"y_mean": [0, 2048],
"x_stddev": [1, 7],
"y_stddev": [1, 7],
"theta": [0, np.pi],
}
@classmethod
def from_config(cls, config, **kwargs):
config = cls.default_config(config=config, method="subset", **kwargs)
return cls(gen_catalog(config.n, config.param_ranges, config.seed), config=config)

def __init__(self, table, return_copy=False):
self.table = table
self.return_copy = return_copy
@classmethod
def from_ranges(cls, n=None, config=None, **kwargs):
config = cls.default_config(n=n, config=config, method="subset")
config.param_ranges.update(**kwargs)
return cls.from_config(config)

@classmethod
def from_params(cls, n=100, param_ranges=None):
param_ranges = {} if param_ranges is None else param_ranges
tmp = cls.base_param_ranges.copy()
tmp.update(param_ranges)
return cls(gen_catalog(n, tmp))

def gen_realization(self, *args, t=None, dt=None, **kwargs):
if self.return_copy:
def from_table(cls, table):
config = cls.default_config()
config.n = len(table)
params = {}
for col in table.keys():
params[col] = (table[col].min(), table[col].max())
config["param_ranges"] = params
return cls(table, config=config)

def mock(self):
self.current += 1
if self.config.return_copy:
return self.table.copy()
return self.table


class SimpleObjectCatalog(CatalogFactory):
base_param_ranges = {
"amplitude": [1, 100],
"x_mean": [0, 4096],
"y_mean": [0, 2048],
"vx": [500, 1000],
"vy": [500, 1000],
"stddev": [1, 1.8],
"theta": [0, np.pi],
class SourceCatalogConfig(SimpleCatalogConfig):
param_ranges = {
"amplitude": [1., 10.],
"x_mean": [0., 4096.],
"y_mean": [0., 2048.],
"x_stddev": [1., 3.],
"y_stddev": [1., 3.],
"theta": [0., np.pi],
}

def __init__(self, table, obstime=None):
self.table = table
self._realization = table.copy()

class SourceCatalog(SimpleCatalog):
default_config = SourceCatalogConfig


class ObjectCatalogConfig(SimpleCatalogConfig):
param_ranges = {
"amplitude": [0.1, 3.0],
"x_mean": [0., 4096.],
"y_mean": [0., 2048.],
"vx": [500., 1000.],
"vy": [500., 1000.],
"stddev": [0.25, 1.5],
"theta": [0., np.pi],
}


class ObjectCatalog(SimpleCatalog):
default_config = ObjectCatalogConfig

def __init__(self, table=None, obstime=None, config=None, **kwargs):
# put return_copy into kwargs to override whatever user might have
# supplied, and to guarantee the default is overriden
kwargs["return_copy"] = True
super().__init__(table=table, config=config, **kwargs)
self._realization = self.table.copy()
self.obstime = 0 if obstime is None else obstime

@classmethod
def from_params(cls, n=100, param_ranges=None):
param_ranges = {} if param_ranges is None else param_ranges
tmp = cls.base_param_ranges.copy()
tmp.update(param_ranges)
return cls(gen_catalog(n, tmp))
def reset(self):
self.current = 0
self._realization = self.table.copy()

def gen_realization(self, t=None, dt=None, **kwargs):
if t is None and dt is None:
return self._realization

dt = dt if t is None else t - self.obstime
self._realization["x_mean"] += self._realization["vx"] * dt
self._realization["y_mean"] += self._realization["vy"] * dt
self._realization["x_mean"] += self.table["vx"] * dt
self._realization["y_mean"] += self.table["vy"] * dt
return self._realization

def mock(self, n=1, **kwargs):
breakpoint()
if n == 1:
return self.gen_realization(**kwargs)
return [self.gen_realization(**kwargs).copy() for i in range(n)]
data = self.gen_realization(**kwargs)
self.current += 1
else:
data = []
for i in range(n):
data.append(self.gen_realization(**kwargs).copy())
self.current += 1

return data
99 changes: 85 additions & 14 deletions src/kbmod/mocking/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

__all__ = ["Config", "ConfigurationError"]


Expand All @@ -22,37 +24,66 @@ class attributes. Particular attributes can be overriden on an per-instance
Keyword arguments, assigned as configuration key-values.
"""

def __init__(self, config=None, **kwargs):
def __init__(self, config=None, method="default", **kwargs):
# This is a bit hacky, but it makes life a lot easier because it
# enables automatic loading of the default configuration and separation
# of default config from instance bound config
keys = list(set(dir(self.__class__)) - set(dir(Config)))

# First fill out all the defaults by copying cls attrs
self._conf = {k: getattr(self, k) for k in keys}
self._conf = {k: copy.copy(getattr(self, k)) for k in keys}

# Then override with any user-specified values
conf = config
if isinstance(config, Config):
conf = config._conf
self.update(config=config, method=method, **kwargs)

if conf is not None:
self._conf.update(config)
self._conf.update(kwargs)
@classmethod
def from_configs(cls, *args):
config = cls()
for conf in args:
config.update(config=conf, method="extend")
return config

# now just shortcut the most common dict operations
def __getitem__(self, key):
return self._conf[key]

# now just shortcut the most common dict operations
def __getattribute__(self, key):
hasconf = "_conf" in object.__getattribute__(self, "__dict__")
if hasconf:
conf = object.__getattribute__(self, "_conf")
if key in conf:
return conf[key]
return object.__getattribute__(self, key)

def __setitem__(self, key, value):
self._conf[key] = value

def __repr__(self):
res = f"{self.__class__.__name__}("
for k, v in self.items():
res += f"{k}: {v}, "
return res[:-2] + ")"

def __str__(self):
res = f"{self.__class__.__name__}("
for k, v in self.items():
res += f"{k}: {v}, "
return res[:-2] + ")"

def _repr_html_(self):
repr = f"""
<table style='tr:nth-child(even){{background-color: #dddddd;}};'>
<caption>{self.__class__.__name__}</caption>
<tr>
<th>Key</th>
<th>Value</th>
</tr>
"""
for k, v in self.items():
repr += f"<tr><td>{k}</td><td>{v}\n"
repr += "</table>"
return repr

def __len__(self):
return len(self._conf)

Expand All @@ -76,7 +107,7 @@ def __or__(self, other):
elif isinstance(other, dict):
return self.__class__(config=self._conf | other)
else:
raise TypeError("unsupported operand type(s) for |: {type(self)} " "and {type(other)}")
raise TypeError("unsupported operand type(s) for |: {type(self)}and {type(other)}")

def keys(self):
"""A set-like object providing a view on config's keys."""
Expand All @@ -90,7 +121,10 @@ def items(self):
"""A set-like object providing a view on config's items."""
return self._conf.items()

def update(self, conf=None, **kwargs):
def copy(self):
return self.__class__(config=self._conf.copy())

def update(self, config=None, method="default", **kwargs):
"""Update this config from dict/other config/iterable and
apply any explicit keyword overrides.
Expand All @@ -107,9 +141,46 @@ def update(self, conf=None, **kwargs):
for k in kwargs: this[k] = kwargs[k]
"""
if conf is not None:
self._conf.update(conf)
self._conf.update(kwargs)
# Python < 3.9 does not support set operations for dicts
# [fixme]: Update this to: other = conf | kwargs
# and remove current implementation when 3.9 gets too old. Order of
# conf and kwargs matter to correctly apply explicit overrides

# Check if both conf and kwargs are given, just conf or just
# kwargs. If none are given do nothing to comply with default
# dict behavior
if config is not None and kwargs:
other = {**config, **kwargs}
elif config is not None:
other = config
elif kwargs is not None:
other = kwargs
else:
return

# then, see if we the given config and overrides are a subset of this
# config or it's superset. Depending on the selected method then raise
# errors, ignore or extend the current config if the given config is a
# superset (or disjoint) from the current one.
subset = {k: v for k, v in other.items() if k in self._conf}
superset = {k: v for k, v in other.items() if k not in subset}

if method.lower() == "default":
if superset:
raise ConfigurationError(
"Tried setting the following fields, not a part of "
f"this configuration options: {superset}"
)
conf = other # == subset
elif method.lower() == "subset":
conf = subset
elif method.lower() == "extend":
conf = other
else:
raise ValueError("Method expected to be one of 'default', "
f"'subset' or 'extend'. Got {method} instead.")

self._conf.update(conf)

def toDict(self):
"""Return this config as a dict."""
Expand Down
Loading

0 comments on commit d797271

Please sign in to comment.