Skip to content

Commit

Permalink
Merge pull request #4 from xuyuon/98-moving-naming-tracking-into-jim-…
Browse files Browse the repository at this point in the history
…class-from-prior-class

98 moving naming tracking into jim class from prior class
  • Loading branch information
xuyuon authored Aug 2, 2024
2 parents 7a7796c + 6374772 commit 24f435c
Show file tree
Hide file tree
Showing 11 changed files with 1,056 additions and 323 deletions.
39 changes: 26 additions & 13 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Jim(object):
# Name of parameters to sample from
sample_transforms: list[BijectiveTransform]
likelihood_transforms: list[NtoMTransform]
parameter_names: list[str]
sampler: Sampler

def __init__(
Expand All @@ -37,11 +38,16 @@ def __init__(

self.sample_transforms = sample_transforms
self.likelihood_transforms = likelihood_transforms
self.parameter_names = prior.parameter_names

if len(sample_transforms) == 0:
print(
"No sample transforms provided. Using prior parameters as sampling parameters"
)
else:
print("Using sample transforms")
for transform in sample_transforms:
self.parameter_names = transform.propagate_name(self.parameter_names)

if len(likelihood_transforms) == 0:
print(
Expand All @@ -64,7 +70,7 @@ def __init__(
self.prior.n_dim, num_layers, hidden_size, num_bins, subkey
)

self.Sampler = Sampler(
self.sampler = Sampler(
self.prior.n_dim,
rng_key,
None, # type: ignore
Expand All @@ -88,7 +94,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
def posterior(self, params: Float[Array, " n_dim"], data: dict):
named_params = self.add_name(params)
transform_jacobian = 0.0
for transform in self.sample_transforms:
for transform in reversed(self.sample_transforms):
named_params, jacobian = transform.inverse(named_params)
transform_jacobian += jacobian
prior = self.prior.log_prob(named_params) + transform_jacobian
Expand All @@ -98,9 +104,11 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict):

def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess.size == 0:
initial_guess_named = self.prior.sample(key, self.Sampler.n_chains)
initial_guess_named = self.prior.sample(key, self.sampler.n_chains)
for transform in self.sample_transforms:
initial_guess_named = jax.vmap(transform.forward)(initial_guess_named)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
self.Sampler.sample(initial_guess, None) # type: ignore
self.sampler.sample(initial_guess, None) # type: ignore

def maximize_likelihood(
self,
Expand Down Expand Up @@ -133,22 +141,24 @@ def print_summary(self, transform: bool = True):
"""

train_summary = self.Sampler.get_sampler_state(training=True)
production_summary = self.Sampler.get_sampler_state(training=False)
train_summary = self.sampler.get_sampler_state(training=True)
production_summary = self.sampler.get_sampler_state(training=False)

training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T
training_chain = self.prior.add_name(training_chain)
training_chain = self.add_name(training_chain)
if transform:
training_chain = self.prior.transform(training_chain)
for sample_transform in self.sample_transforms:
training_chain = sample_transform.backward(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
training_loss = train_summary["loss_vals"]

production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T
production_chain = self.prior.add_name(production_chain)
production_chain = self.add_name(production_chain)
if transform:
production_chain = self.prior.transform(production_chain)
for sample_transform in self.sample_transforms:
production_chain = sample_transform.backward(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]
Expand Down Expand Up @@ -200,11 +210,14 @@ def get_samples(self, training: bool = False) -> dict:
"""
if training:
chains = self.Sampler.get_sampler_state(training=True)["chains"]
chains = self.sampler.get_sampler_state(training=True)["chains"]
else:
chains = self.Sampler.get_sampler_state(training=False)["chains"]
chains = self.sampler.get_sampler_state(training=False)["chains"]

chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1)))
chains = chains.transpose(2, 0, 1)
chains = self.add_name(chains)
for sample_transform in self.sample_transforms:
chains = sample_transform.backward(chains)
return chains

def plot(self):
Expand Down
219 changes: 0 additions & 219 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,227 +468,8 @@ def __init__(
)


@jaxtyped(typechecker=typechecker)
class UniformInComponentsChirpMassPrior(PowerLawPrior):
"""
A prior in the range [xmin, xmax) for chirp mass which assumes the
component masses to be uniformly distributed.
p(M_c) ~ M_c
"""

def __repr__(self):
return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})"

def __init__(self, xmin: float, xmax: float):
super().__init__(xmin, xmax, 1.0, ["M_c"])


def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]:
if prior.composite:
if isinstance(prior.base_prior, list):
for subprior in prior.base_prior:
output = trace_prior_parent(subprior, output)
elif isinstance(prior.base_prior, Prior):
output = trace_prior_parent(prior.base_prior, output)
else:
output.append(prior)

return output


# ====================== Things below may need rework ======================


# @jaxtyped(typechecker=typechecker)
# class AlignedSpin(Prior):
# """
# Prior distribution for the aligned (z) component of the spin.

# This assume the prior distribution on the spin magnitude to be uniform in [0, amax]
# with its orientation uniform on a sphere

# p(chi) = -log(|chi| / amax) / 2 / amax

# This is useful when comparing results between an aligned-spin run and
# a precessing spin run.

# See (A7) of https://arxiv.org/abs/1805.10457.
# """

# amax: Float = 0.99
# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000))
# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000))

# def __repr__(self):
# return f"Alignedspin(amax={self.amax}, naming={self.naming})"

# def __init__(
# self,
# amax: Float,
# naming: list[str],
# transforms: dict[str, tuple[str, Callable]] = {},
# **kwargs,
# ):
# super().__init__(naming, transforms)
# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions"
# self.amax = amax

# # build the interpolation table for the ppf of the one-sided distribution
# chi_axis = jnp.linspace(1e-31, self.amax, num=1000)
# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax
# self.chi_axis = chi_axis
# self.cdf_vals = cdf_vals

# @property
# def xmin(self):
# return -self.amax

# @property
# def xmax(self):
# return self.amax

# def sample(
# self, rng_key: PRNGKeyArray, n_samples: int
# ) -> dict[str, Float[Array, " n_samples"]]:
# """
# Sample from the Alignedspin distribution.

# for chi > 0;
# p(chi) = -log(chi / amax) / amax # halved normalization constant
# cdf(chi) = -chi * (log(chi / amax) - 1) / amax

# Since there is a pole at chi=0, we will sample with the following steps
# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise
# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q)
# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5)
# 3. Map the quantile to chi via the ppf by checking against the table
# built during the initialization
# 4. add back the sign

# Parameters
# ----------
# rng_key : PRNGKeyArray
# A random key to use for sampling.
# n_samples : int
# The number of samples to draw.

# Returns
# -------
# samples : dict
# Samples from the distribution. The keys are the names of the parameters.

# """
# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0)
# # 1. calculate the sign of chi from the q_samples
# sign_samples = jnp.where(
# q_samples >= 0.5,
# jnp.zeros_like(q_samples) + 1.0,
# jnp.zeros_like(q_samples) - 1.0,
# )
# # 2. remap q_samples
# q_samples = jnp.where(
# q_samples >= 0.5,
# 2 * (q_samples - 0.5),
# 2 * (0.5 - q_samples),
# )
# # 3. map the quantile to chi via interpolation
# samples = jnp.interp(
# q_samples,
# self.cdf_vals,
# self.chi_axis,
# )
# # 4. add back the sign
# samples *= sign_samples

# return self.add_name(samples[None])

# def log_prob(self, x: dict[str, Float]) -> Float:
# variable = x[self.naming[0]]
# log_p = jnp.where(
# (variable >= self.amax) | (variable <= -self.amax),
# jnp.zeros_like(variable) - jnp.inf,
# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax),
# )
# return log_p


# @jaxtyped(typechecker=typechecker)
# class EarthFrame(Prior):
# """
# Prior distribution for sky location in Earth frame.
# """

# ifos: list = field(default_factory=list)
# gmst: float = 0.0
# delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3))

# def __repr__(self):
# return f"EarthFrame(naming={self.naming})"

# def __init__(self, gps: Float, ifos: list, **kwargs):
# self.naming = ["zenith", "azimuth"]
# if len(ifos) < 2:
# return ValueError(
# "At least two detectors are needed to define the Earth frame"
# )
# elif isinstance(ifos[0], str):
# self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]]
# elif isinstance(ifos[0], GroundBased2G):
# self.ifos = ifos[:1]
# else:
# return ValueError(
# "ifos should be a list of detector names or GroundBased2G objects"
# )
# self.gmst = float(
# Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad
# )
# self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex

# self.transforms = {
# "azimuth": (
# "ra",
# lambda params: zenith_azimuth_to_ra_dec(
# params["zenith"],
# params["azimuth"],
# gmst=self.gmst,
# delta_x=self.delta_x,
# )[0],
# ),
# "zenith": (
# "dec",
# lambda params: zenith_azimuth_to_ra_dec(
# params["zenith"],
# params["azimuth"],
# gmst=self.gmst,
# delta_x=self.delta_x,
# )[1],
# ),
# }

# def sample(
# self, rng_key: PRNGKeyArray, n_samples: int
# ) -> dict[str, Float[Array, " n_samples"]]:
# rng_keys = jax.random.split(rng_key, 2)
# zenith = jnp.arccos(
# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0)
# )
# azimuth = jax.random.uniform(
# rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi
# )
# return self.add_name(jnp.stack([zenith, azimuth], axis=1).T)

# def log_prob(self, x: dict[str, Float]) -> Float:
# zenith = x["zenith"]
# azimuth = x["azimuth"]
# output = jnp.where(
# (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0),
# jnp.zeros_like(0) - jnp.inf,
# jnp.zeros_like(0),
# )
# return output + jnp.log(jnp.sin(zenith))


# @jaxtyped(typechecker=typechecker)
# class Exponential(Prior):
# """
Expand Down
2 changes: 1 addition & 1 deletion src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from jimgw.base import LikelihoodBase
from jimgw.prior import Prior
from jimgw.single_event.detector import Detector
from jimgw.single_event.utils import log_i0
from jimgw.utils import log_i0
from jimgw.single_event.waveform import Waveform


Expand Down
Loading

0 comments on commit 24f435c

Please sign in to comment.