Skip to content

Commit

Permalink
Fixup.
Browse files Browse the repository at this point in the history
  • Loading branch information
DinoBektesevic committed Aug 15, 2024
1 parent 0e244aa commit abd0f6e
Show file tree
Hide file tree
Showing 10 changed files with 772 additions and 499 deletions.
32 changes: 25 additions & 7 deletions src/kbmod/mocking/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,42 @@
import random

from astropy.time import Time
import astropy.units as u

__all__ = [
"IncrementObstime",
"ObstimeIterator",
]


class IncrementObstime:
default_unit = "day"
def __init__(self, start, dt):
self.start = start
self.start = Time(start)
if not isinstance(dt, u.Quantity):
dt = dt * getattr(u, self.default_unit)
self.dt = dt

def __call__(self, mut_val):
def __call__(self, header_val):
curr = self.start
self.start += self.dt
return curr
return curr.fits


class ObstimeIterator:
def __init__(self, obstimes):
self.obstimes = obstimes
def __init__(self, obstimes, **kwargs):
self.obstimes = Time(obstimes, **kwargs)
self.generator = (t for t in obstimes)

def __call__(self, mut_val):
return next(self.generator)
def __call__(self, header_val):
return Time(next(self.generator)).fits


class DitherValue:
def __init__(self, value, dither_range):
self.value = value
self.dither_range = dither_range

def __call__(self, header_val):
return self.value + random.uniform(self.dither_range)

63 changes: 41 additions & 22 deletions src/kbmod/mocking/catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import numpy as np
from astropy.table import QTable
from astropy.coordinates import SkyCoord

from .config import Config


__all__ = [
"gen_catalog",
"gen_random_catalog",
"CatalogFactory",
"SimpleCatalog",
"SourceCatalogConfig",
Expand All @@ -16,7 +18,7 @@
]


def gen_catalog(n, param_ranges, seed=None):
def gen_random_catalog(n, param_ranges, seed=None):
cat = QTable()
rng = np.random.default_rng(seed)

Expand All @@ -30,9 +32,6 @@ def gen_catalog(n, param_ranges, seed=None):

# conversion assumes a gaussian
if "flux" in param_ranges and "amplitude" not in param_ranges:
xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1.0
ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1.0

cat["amplitude"] = cat["flux"] / (2.0 * np.pi * xstd * ystd)

return cat
Expand All @@ -45,7 +44,8 @@ def mock(self, *args, **kwargs):


class SimpleCatalogConfig(Config):
mode = "static"
mode = "static" # folding
kind = "pixel" # world
return_copy = False
seed = None
n = 100
Expand All @@ -56,28 +56,32 @@ class SimpleCatalog(CatalogFactory):
default_config = SimpleCatalogConfig

def __init__(self, config, table, **kwargs):
config = self.default_config(**kwargs)
self.config = config
self.config = self.default_config(config=config, **kwargs)
self.table = table
self.current = 0

@classmethod
def from_config(cls, config, **kwargs):
config = cls.default_config(config=config, **kwargs)
table = gen_catalog(config.n, config.param_ranges, config.seed)
table = gen_random_catalog(config["n"], config["param_ranges"], config["seed"])
return cls(config, table)

@classmethod
def from_defaults(cls, param_ranges=None, **kwargs):
config = cls.default_config(**kwargs)
if param_ranges is not None:
config.param_ranges.update(param_ranges)
config["param_ranges"].update(param_ranges)
return cls.from_config(config)

@classmethod
def from_table(cls, table, **kwargs):
if "x_stddev" not in table.columns:
table["x_stddev"] = table["stddev"]
if "y_stddev" not in table.columns:
table["y_stddev"] = table["stddev"]

config = cls.default_config(**kwargs)
config.n = len(table)
config["n"] = len(table)
params = {}
for col in table.keys():
params[col] = (table[col].min(), table[col].max())
Expand All @@ -86,7 +90,7 @@ def from_table(cls, table, **kwargs):

def mock(self):
self.current += 1
if self.config.return_copy:
if self.config["return_copy:"]:
return self.table.copy()
return self.table

Expand Down Expand Up @@ -127,7 +131,7 @@ def __init__(self, config, table, **kwargs):
kwargs["return_copy"] = True
super().__init__(config, table, **kwargs)
self._realization = self.table.copy()
self.mode = self.config.mode
self.mode = self.config["mode"]

@property
def mode(self):
Expand All @@ -137,8 +141,10 @@ def mode(self):
def mode(self, val):
if val == "folding":
self._gen_realization = self.fold
elif val == "progressive":
self._gen_realization = self.next
elif val == "progressive" and self.config["kind"] == "pixel":
self._gen_realization = self.next_pixel
elif val == "progressive" and self.config["kind"] == "world":
self._gen_realization = self.next_world
elif val == "static":
self._gen_realization = self.static
else:
Expand All @@ -155,25 +161,38 @@ def reset(self):
def static(self, **kwargs):
return self.table.copy()

def next(self, dt, **kwargs):
self._realization["x_mean"] = self.table["x_mean"] + self.current * self._realization["vx"] * dt
self._realization["y_mean"] = self.table["y_mean"] + self.current * self._realization["vy"] * dt
def _next(self, dt, keys):
a, va, b, vb = keys
self._realization[a] = self.table[a] + self.current * self.table[va] * dt
self._realization[b] = self.table[b] + self.current * self.table[vb] * dt
self.current += 1
return self._realization.copy()

def next_world(self, dt):
return self._next(dt, ["ra_mean", "v_ra", "dec_mean", "v_dec"])

def next_pixel(self, dt):
return self._next(dt, ["x_mean", "vx", "y_mean", "vy"])

def fold(self, t, **kwargs):
self._realization = self.table[self.table["obstime"] == t]
self.current += 1
return self._realization.copy()

def mock(self, n=1, **kwargs):
def mock(self, n=1, dt=None, t=None, wcs=None):
data = []

if self.mode == "folding":
for t in kwargs["t"]:
data.append(self.fold(t=t))
for i, ts in enumerate(t):
data.append(self.fold(t=ts))
else:
for i in range(n):
data.append(self._gen_realization(**kwargs))
data.append(self._gen_realization(dt))

if self.config["kind"] == "world":
for cat, w in zip(data, wcs):
x, y = w.world_to_pixel(SkyCoord(ra=cat["ra_mean"], dec=cat["dec_mean"], unit="deg"))
cat["x_mean"] = x
cat["y_mean"] = y

return data
126 changes: 22 additions & 104 deletions src/kbmod/mocking/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@ class ConfigurationError(Exception):


class Config:
"""Base configuration class.
"""Base class for Standardizer configuration.
Config classes that inherit from this class define configuration as their
class attributes. Particular attributes can be overriden on an per-instance
basis by providing a config overrides at initialization time.
Configs inheriting from this config support basic dictionary operations.
Not all standardizers will (can) use the same parameters so refer to their
respective documentation for a more complete list.
Parameters
----------
Expand All @@ -24,66 +21,33 @@ class attributes. Particular attributes can be overriden on an per-instance
Keyword arguments, assigned as configuration key-values.
"""

def __init__(self, config=None, method="default", **kwargs):
def __init__(self, config=None, **kwargs):
# This is a bit hacky, but it makes life a lot easier because it
# enables automatic loading of the default configuration and separation
# of default config from instance bound config
keys = list(set(dir(self.__class__)) - set(dir(Config)))

# First fill out all the defaults by copying cls attrs
self._conf = {k: copy.copy(getattr(self, k)) for k in keys}
self._conf = {k: getattr(self, k) for k in keys}

# Then override with any user-specified values
self.update(config=config, method=method, **kwargs)

@classmethod
def from_configs(cls, *args):
config = cls()
for conf in args:
config.update(config=conf, method="extend")
return config
if config is not None:
self._conf.update(config)
self._conf.update(kwargs)

# now just shortcut the most common dict operations
def __getitem__(self, key):
return self._conf[key]

# now just shortcut the most common dict operations
def __getattribute__(self, key):
hasconf = "_conf" in object.__getattribute__(self, "__dict__")
if hasconf:
conf = object.__getattribute__(self, "_conf")
if key in conf:
return conf[key]
return object.__getattribute__(self, key)

def __setitem__(self, key, value):
self._conf[key] = value

def __repr__(self):
res = f"{self.__class__.__name__}("
for k, v in self.items():
res += f"{k}: {v}, "
return res[:-2] + ")"

def __str__(self):
res = f"{self.__class__.__name__}("
for k, v in self.items():
res += f"{k}: {v}, "
return res[:-2] + ")"

def _repr_html_(self):
repr = f"""
<table style='tr:nth-child(even){{background-color: #dddddd;}};'>
<caption>{self.__class__.__name__}</caption>
<tr>
<th>Key</th>
<th>Value</th>
</tr>
"""
for k, v in self.items():
repr += f"<tr><td>{k}</td><td>{v}\n"
repr += "</table>"
return repr

def __len__(self):
return len(self._conf)

Expand All @@ -107,7 +71,7 @@ def __or__(self, other):
elif isinstance(other, dict):
return self.__class__(config=self._conf | other)
else:
raise TypeError("unsupported operand type(s) for |: {type(self)}and {type(other)}")
raise TypeError("unsupported operand type(s) for |: {type(self)} " "and {type(other)}")

def keys(self):
"""A set-like object providing a view on config's keys."""
Expand All @@ -121,68 +85,22 @@ def items(self):
"""A set-like object providing a view on config's items."""
return self._conf.items()

def copy(self):
return self.__class__(config=self._conf.copy())

def update(self, config=None, method="default", **kwargs):
"""Update this config from dict/other config/iterable and
apply any explicit keyword overrides.
A dict-like update. If ``conf`` is given and has a ``.keys()``
method, performs:
for k in conf: this[k] = conf[k]
If ``conf`` is given but lacks a ``.keys()`` method, performs:
def update(self, conf=None, **kwargs):
"""Update this config from dict/other config/iterable.
for k, v in conf: this[k] = v
A dict-like update. If ``conf`` is present and has a ``.keys()``
method, then does: ``for k in conf: this[k] = conf[k]``. If ``conf``
is present but lacks a ``.keys()`` method, then does:
``for k, v in conf: this[k] = v``.
In both cases, explicit overrides are applied at the end:
for k in kwargs: this[k] = kwargs[k]
In either case, this is followed by:
``for k in kwargs: this[k] = kwargs[k]``
"""
# Python < 3.9 does not support set operations for dicts
# [fixme]: Update this to: other = conf | kwargs
# and remove current implementation when 3.9 gets too old. Order of
# conf and kwargs matter to correctly apply explicit overrides

# Check if both conf and kwargs are given, just conf or just
# kwargs. If none are given do nothing to comply with default
# dict behavior
if config is not None and kwargs:
other = {**config, **kwargs}
elif config is not None:
other = config
elif kwargs is not None:
other = kwargs
else:
return

# then, see if we the given config and overrides are a subset of this
# config or it's superset. Depending on the selected method then raise
# errors, ignore or extend the current config if the given config is a
# superset (or disjoint) from the current one.
subset = {k: v for k, v in other.items() if k in self._conf}
superset = {k: v for k, v in other.items() if k not in subset}

if method.lower() == "default":
if superset:
raise ConfigurationError(
"Tried setting the following fields, not a part of "
f"this configuration options: {superset}"
)
conf = other # == subset
elif method.lower() == "subset":
conf = subset
elif method.lower() == "extend":
conf = other
else:
raise ValueError(
"Method expected to be one of 'default', " f"'subset' or 'extend'. Got {method} instead."
)

self._conf.update(conf)
if conf is not None:
self._conf.update(conf)
self._conf.update(kwargs)

def toDict(self):
"""Return this config as a dict."""
return self._conf

Loading

0 comments on commit abd0f6e

Please sign in to comment.