diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 3186c2c4..b6c91582 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -22,6 +22,7 @@ class Jim(object): # Name of parameters to sample from sample_transforms: list[BijectiveTransform] likelihood_transforms: list[NtoMTransform] + parameter_names: list[str] sampler: Sampler def __init__( @@ -37,11 +38,16 @@ def __init__( self.sample_transforms = sample_transforms self.likelihood_transforms = likelihood_transforms + self.parameter_names = prior.parameter_names if len(sample_transforms) == 0: print( "No sample transforms provided. Using prior parameters as sampling parameters" ) + else: + print("Using sample transforms") + for transform in sample_transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) if len(likelihood_transforms) == 0: print( @@ -64,7 +70,7 @@ def __init__( self.prior.n_dim, num_layers, hidden_size, num_bins, subkey ) - self.Sampler = Sampler( + self.sampler = Sampler( self.prior.n_dim, rng_key, None, # type: ignore @@ -88,7 +94,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) transform_jacobian = 0.0 - for transform in self.sample_transforms: + for transform in reversed(self.sample_transforms): named_params, jacobian = transform.inverse(named_params) transform_jacobian += jacobian prior = self.prior.log_prob(named_params) + transform_jacobian @@ -98,9 +104,11 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: - initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) + initial_guess_named = self.prior.sample(key, self.sampler.n_chains) + for transform in self.sample_transforms: + initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T - self.Sampler.sample(initial_guess, None) # type: ignore + self.sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( self, @@ -133,22 +141,24 @@ def print_summary(self, transform: bool = True): """ - train_summary = self.Sampler.get_sampler_state(training=True) - production_summary = self.Sampler.get_sampler_state(training=False) + train_summary = self.sampler.get_sampler_state(training=True) + production_summary = self.sampler.get_sampler_state(training=False) training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T - training_chain = self.prior.add_name(training_chain) + training_chain = self.add_name(training_chain) if transform: - training_chain = self.prior.transform(training_chain) + for sample_transform in self.sample_transforms: + training_chain = sample_transform.backward(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] training_loss = train_summary["loss_vals"] production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T - production_chain = self.prior.add_name(production_chain) + production_chain = self.add_name(production_chain) if transform: - production_chain = self.prior.transform(production_chain) + for sample_transform in self.sample_transforms: + production_chain = sample_transform.backward(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -200,11 +210,14 @@ def get_samples(self, training: bool = False) -> dict: """ if training: - chains = self.Sampler.get_sampler_state(training=True)["chains"] + chains = self.sampler.get_sampler_state(training=True)["chains"] else: - chains = self.Sampler.get_sampler_state(training=False)["chains"] + chains = self.sampler.get_sampler_state(training=False)["chains"] - chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1))) + chains = chains.transpose(2, 0, 1) + chains = self.add_name(chains) + for sample_transform in self.sample_transforms: + chains = sample_transform.backward(chains) return chains def plot(self): diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 7cfe04e2..0df66ecd 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -468,227 +468,8 @@ def __init__( ) -@jaxtyped(typechecker=typechecker) -class UniformInComponentsChirpMassPrior(PowerLawPrior): - """ - A prior in the range [xmin, xmax) for chirp mass which assumes the - component masses to be uniformly distributed. - - p(M_c) ~ M_c - """ - - def __repr__(self): - return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" - - def __init__(self, xmin: float, xmax: float): - super().__init__(xmin, xmax, 1.0, ["M_c"]) - - -def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: - if prior.composite: - if isinstance(prior.base_prior, list): - for subprior in prior.base_prior: - output = trace_prior_parent(subprior, output) - elif isinstance(prior.base_prior, Prior): - output = trace_prior_parent(prior.base_prior, output) - else: - output.append(prior) - - return output - - # ====================== Things below may need rework ====================== - -# @jaxtyped(typechecker=typechecker) -# class AlignedSpin(Prior): -# """ -# Prior distribution for the aligned (z) component of the spin. - -# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] -# with its orientation uniform on a sphere - -# p(chi) = -log(|chi| / amax) / 2 / amax - -# This is useful when comparing results between an aligned-spin run and -# a precessing spin run. - -# See (A7) of https://arxiv.org/abs/1805.10457. -# """ - -# amax: Float = 0.99 -# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) -# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - -# def __repr__(self): -# return f"Alignedspin(amax={self.amax}, naming={self.naming})" - -# def __init__( -# self, -# amax: Float, -# naming: list[str], -# transforms: dict[str, tuple[str, Callable]] = {}, -# **kwargs, -# ): -# super().__init__(naming, transforms) -# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" -# self.amax = amax - -# # build the interpolation table for the ppf of the one-sided distribution -# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) -# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax -# self.chi_axis = chi_axis -# self.cdf_vals = cdf_vals - -# @property -# def xmin(self): -# return -self.amax - -# @property -# def xmax(self): -# return self.amax - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# """ -# Sample from the Alignedspin distribution. - -# for chi > 0; -# p(chi) = -log(chi / amax) / amax # halved normalization constant -# cdf(chi) = -chi * (log(chi / amax) - 1) / amax - -# Since there is a pole at chi=0, we will sample with the following steps -# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise -# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) -# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) -# 3. Map the quantile to chi via the ppf by checking against the table -# built during the initialization -# 4. add back the sign - -# Parameters -# ---------- -# rng_key : PRNGKeyArray -# A random key to use for sampling. -# n_samples : int -# The number of samples to draw. - -# Returns -# ------- -# samples : dict -# Samples from the distribution. The keys are the names of the parameters. - -# """ -# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) -# # 1. calculate the sign of chi from the q_samples -# sign_samples = jnp.where( -# q_samples >= 0.5, -# jnp.zeros_like(q_samples) + 1.0, -# jnp.zeros_like(q_samples) - 1.0, -# ) -# # 2. remap q_samples -# q_samples = jnp.where( -# q_samples >= 0.5, -# 2 * (q_samples - 0.5), -# 2 * (0.5 - q_samples), -# ) -# # 3. map the quantile to chi via interpolation -# samples = jnp.interp( -# q_samples, -# self.cdf_vals, -# self.chi_axis, -# ) -# # 4. add back the sign -# samples *= sign_samples - -# return self.add_name(samples[None]) - -# def log_prob(self, x: dict[str, Float]) -> Float: -# variable = x[self.naming[0]] -# log_p = jnp.where( -# (variable >= self.amax) | (variable <= -self.amax), -# jnp.zeros_like(variable) - jnp.inf, -# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), -# ) -# return log_p - - -# @jaxtyped(typechecker=typechecker) -# class EarthFrame(Prior): -# """ -# Prior distribution for sky location in Earth frame. -# """ - -# ifos: list = field(default_factory=list) -# gmst: float = 0.0 -# delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) - -# def __repr__(self): -# return f"EarthFrame(naming={self.naming})" - -# def __init__(self, gps: Float, ifos: list, **kwargs): -# self.naming = ["zenith", "azimuth"] -# if len(ifos) < 2: -# return ValueError( -# "At least two detectors are needed to define the Earth frame" -# ) -# elif isinstance(ifos[0], str): -# self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] -# elif isinstance(ifos[0], GroundBased2G): -# self.ifos = ifos[:1] -# else: -# return ValueError( -# "ifos should be a list of detector names or GroundBased2G objects" -# ) -# self.gmst = float( -# Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad -# ) -# self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex - -# self.transforms = { -# "azimuth": ( -# "ra", -# lambda params: zenith_azimuth_to_ra_dec( -# params["zenith"], -# params["azimuth"], -# gmst=self.gmst, -# delta_x=self.delta_x, -# )[0], -# ), -# "zenith": ( -# "dec", -# lambda params: zenith_azimuth_to_ra_dec( -# params["zenith"], -# params["azimuth"], -# gmst=self.gmst, -# delta_x=self.delta_x, -# )[1], -# ), -# } - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# rng_keys = jax.random.split(rng_key, 2) -# zenith = jnp.arccos( -# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) -# ) -# azimuth = jax.random.uniform( -# rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi -# ) -# return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) - -# def log_prob(self, x: dict[str, Float]) -> Float: -# zenith = x["zenith"] -# azimuth = x["azimuth"] -# output = jnp.where( -# (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), -# jnp.zeros_like(0) - jnp.inf, -# jnp.zeros_like(0), -# ) -# return output + jnp.log(jnp.sin(zenith)) - - # @jaxtyped(typechecker=typechecker) # class Exponential(Prior): # """ diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index f10aeed1..ce2e8f0e 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -12,7 +12,7 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior from jimgw.single_event.detector import Detector -from jimgw.single_event.utils import log_i0 +from jimgw.utils import log_i0 from jimgw.single_event.waveform import Waveform diff --git a/src/jimgw/single_event/prior.py b/src/jimgw/single_event/prior.py new file mode 100644 index 00000000..51a754eb --- /dev/null +++ b/src/jimgw/single_event/prior.py @@ -0,0 +1,179 @@ +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import jaxtyped + +from jimgw.prior import ( + Prior, + CombinePrior, + UniformPrior, + PowerLawPrior, + SinePrior, +) + + +@jaxtyped(typechecker=typechecker) +class UniformSpherePrior(CombinePrior): + + def __repr__(self): + return f"UniformSpherePrior(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str], **kwargs): + self.parameter_names = parameter_names + assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector" + self.parameter_names = [ + f"{self.parameter_names[0]}_mag", + f"{self.parameter_names[0]}_theta", + f"{self.parameter_names[0]}_phi", + ] + super().__init__( + [ + UniformPrior(0.0, 1.0, [self.parameter_names[0]]), + SinePrior([self.parameter_names[1]]), + UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + ] + ) + + +@jaxtyped(typechecker=typechecker) +class UniformComponentChirpMassPrior(PowerLawPrior): + """ + A prior in the range [xmin, xmax) for chirp mass which assumes the + component masses to be uniformly distributed. + + p(M_c) ~ M_c + """ + + def __repr__(self): + return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + + def __init__(self, xmin: float, xmax: float): + super().__init__(xmin, xmax, 1.0, ["M_c"]) + + +def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output + + +# ====================== Things below may need rework ====================== + + +# @jaxtyped(typechecker=typechecker) +# class AlignedSpin(Prior): +# """ +# Prior distribution for the aligned (z) component of the spin. + +# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] +# with its orientation uniform on a sphere + +# p(chi) = -log(|chi| / amax) / 2 / amax + +# This is useful when comparing results between an aligned-spin run and +# a precessing spin run. + +# See (A7) of https://arxiv.org/abs/1805.10457. +# """ + +# amax: Float = 0.99 +# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) +# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + +# def __repr__(self): +# return f"Alignedspin(amax={self.amax}, naming={self.naming})" + +# def __init__( +# self, +# amax: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" +# self.amax = amax + +# # build the interpolation table for the ppf of the one-sided distribution +# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) +# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax +# self.chi_axis = chi_axis +# self.cdf_vals = cdf_vals + +# @property +# def xmin(self): +# return -self.amax + +# @property +# def xmax(self): +# return self.amax + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from the Alignedspin distribution. + +# for chi > 0; +# p(chi) = -log(chi / amax) / amax # halved normalization constant +# cdf(chi) = -chi * (log(chi / amax) - 1) / amax + +# Since there is a pole at chi=0, we will sample with the following steps +# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise +# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) +# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) +# 3. Map the quantile to chi via the ppf by checking against the table +# built during the initialization +# 4. add back the sign + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# # 1. calculate the sign of chi from the q_samples +# sign_samples = jnp.where( +# q_samples >= 0.5, +# jnp.zeros_like(q_samples) + 1.0, +# jnp.zeros_like(q_samples) - 1.0, +# ) +# # 2. remap q_samples +# q_samples = jnp.where( +# q_samples >= 0.5, +# 2 * (q_samples - 0.5), +# 2 * (0.5 - q_samples), +# ) +# # 3. map the quantile to chi via interpolation +# samples = jnp.interp( +# q_samples, +# self.cdf_vals, +# self.chi_axis, +# ) +# # 4. add back the sign +# samples *= sign_samples + +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_p = jnp.where( +# (variable >= self.amax) | (variable <= -self.amax), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), +# ) +# return log_p diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py new file mode 100644 index 00000000..c3e77846 --- /dev/null +++ b/src/jimgw/single_event/transforms.py @@ -0,0 +1,231 @@ +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped +from astropy.time import Time + +from jimgw.single_event.detector import GroundBased2G +from jimgw.transforms import BijectiveTransform, NtoNTransform +from jimgw.single_event.utils import ( + m1_m2_to_Mc_q, + Mc_q_to_m1_m2, + m1_m2_to_Mc_eta, + Mc_eta_to_m1_m2, + q_to_eta, + eta_to_q, + ra_dec_to_zenith_azimuth, + zenith_azimuth_to_ra_dec, + euler_rotation, + spin_to_cartesian_spin, +) + + +@jaxtyped(typechecker=typechecker) +class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to component masses + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + assert ( + "m_1" in name_mapping[0] + and "m_2" in name_mapping[0] + and "M_c" in name_mapping[1] + and "q" in name_mapping[1] + ) + + def named_transform(x): + Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) + return {"M_c": Mc, "q": q} + + self.transform_func = named_transform + + def named_inverse_transform(x): + m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} + + self.inverse_transform_func = named_inverse_transform + + +@jaxtyped(typechecker=typechecker) +class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform): + """ + Transform mass ratio to symmetric mass ratio + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + assert ( + "m_1" in name_mapping[0] + and "m_2" in name_mapping[0] + and "M_c" in name_mapping[1] + and "eta" in name_mapping[1] + ) + + def named_transform(x): + Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) + return {"M_c": Mc, "eta": eta} + + self.transform_func = named_transform + + def named_inverse_transform(x): + m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} + + self.inverse_transform_func = named_inverse_transform + + +@jaxtyped(typechecker=typechecker) +class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): + """ + Transform mass ratio to symmetric mass ratio + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0] + + self.transform_func = lambda x: {"eta": q_to_eta(x["q"])} + self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])} + + +@jaxtyped(typechecker=typechecker) +class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): + """ + Transform sky frame to detector frame sky position + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + rotation: Float[Array, " 3 3"] + rotation_inv: Float[Array, " 3 3"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifos: GroundBased2G, + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + delta_x = ifos[0].vertex - ifos[1].vertex + self.rotation = euler_rotation(delta_x) + self.rotation_inv = jnp.linalg.inv(self.rotation) + + assert ( + "ra" in name_mapping[0] + and "dec" in name_mapping[0] + and "zenith" in name_mapping[1] + and "azimuth" in name_mapping[1] + ) + + def named_transform(x): + zenith, azimuth = ra_dec_to_zenith_azimuth( + x["ra"], x["dec"], self.gmst, self.rotation + ) + return {"zenith": zenith, "azimuth": azimuth} + + self.transform_func = named_transform + + def named_inverse_transform(x): + ra, dec = zenith_azimuth_to_ra_dec( + x["zenith"], x["azimuth"], self.gmst, self.rotation_inv + ) + return {"ra": ra, "dec": dec} + + self.inverse_transform_func = named_inverse_transform + + +@jaxtyped(typechecker=typechecker) +class SpinToCartesianSpinTransform(NtoNTransform): + """ + Spin to Cartesian spin transformation + """ + + freq_ref: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + freq_ref: Float, + ): + super().__init__(name_mapping) + + self.freq_ref = freq_ref + + assert ( + "theta_jn" in name_mapping[0] + and "phi_jl" in name_mapping[0] + and "theta_1" in name_mapping[0] + and "theta_2" in name_mapping[0] + and "phi_12" in name_mapping[0] + and "a_1" in name_mapping[0] + and "a_2" in name_mapping[0] + and "iota" in name_mapping[1] + and "s1_x" in name_mapping[1] + and "s1_y" in name_mapping[1] + and "s1_z" in name_mapping[1] + and "s2_x" in name_mapping[1] + and "s2_y" in name_mapping[1] + and "s2_z" in name_mapping[1] + ) + + def named_transform(x): + iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( + x["theta_jn"], + x["phi_jl"], + x["theta_1"], + x["theta_2"], + x["phi_12"], + x["a_1"], + x["a_2"], + x["M_c"], + x["q"], + self.freq_ref, + x["phase_c"], + ) + return { + "iota": iota, + "s1_x": s1x, + "s1_y": s1y, + "s1_z": s1z, + "s2_x": s2x, + "s2_y": s2y, + "s2_z": s2z, + } + + self.transform_func = named_transform diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 3f0decce..a15bd7bf 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -1,8 +1,9 @@ import jax.numpy as jnp from jax.scipy.integrate import trapezoid -from jax.scipy.special import i0e from jaxtyping import Array, Float +from jimgw.constants import Msun + def inner_product( h1: Float[Array, " n_sample"], @@ -37,7 +38,7 @@ def inner_product( return 4.0 * jnp.real(trapezoid(integrand, dx=df)) -def m1m2_to_Mq(m1: Float, m2: Float): +def m1_m2_to_M_q(m1: Float, m2: Float): """ Transforming the primary mass m1 and secondary mass m2 to the Total mass M and mass ratio q. @@ -56,12 +57,12 @@ def m1m2_to_Mq(m1: Float, m2: Float): q : Float Mass ratio. """ - M_tot = jnp.log(m1 + m2) - q = jnp.log(m2 / m1) - jnp.log(1 - m2 / m1) + M_tot = m1 + m2 + q = m2 / m1 return M_tot, q -def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): +def M_q_to_m1_m2(M_tot: Float, q: Float): """ Transforming the Total mass M and mass ratio q to the primary mass m1 and secondary mass m2. @@ -80,21 +81,45 @@ def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): m2 : Float Secondary mass. """ - M_tot = jnp.exp(trans_M_tot) - q = 1.0 / (1 + jnp.exp(-trans_q)) m1 = M_tot / (1 + q) m2 = m1 * q return m1, m2 -def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: +def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c + and mass ratio q. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + """ + M_tot = m1 + m2 + eta = m1 * m2 / M_tot**2 + M_c = M_tot * eta ** (3.0 / 5) + q = m2 / m1 + return M_c, q + + +def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]: """ - Transforming the chirp mass Mc and mass ratio q to the primary mass m1 and + Transforming the chirp mass M_c and mass ratio q to the primary mass m1 and secondary mass m2. Parameters ---------- - Mc : Float + M_c : Float Chirp mass. q : Float Mass ratio. @@ -107,36 +132,148 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: Secondary mass. """ eta = q / (1 + q) ** 2 - M_tot = Mc / eta ** (3.0 / 5) + M_tot = M_c / eta ** (3.0 / 5) m1 = M_tot / (1 + q) m2 = m1 * q return m1, m2 -def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: +def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: """ - Transforming the right ascension ra and declination dec to the polar angle - theta and azimuthal angle phi. + Transforming the primary mass m1 and secondary mass m2 to the total mass M + and symmetric mass ratio eta. Parameters ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. + m1 : Float + Primary mass. + m2 : Float + Secondary mass. Returns ------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. + M : Float + Total mass. + eta : Float + Symmetric mass ratio. """ - phi = ra - gmst - theta = jnp.pi / 2 - dec - return theta, phi + M_tot = m1 + m2 + eta = m1 * m2 / M_tot**2 + return M_tot, eta + + +def M_eta_to_m1_m2(M_tot: Float, eta: Float) -> tuple[Float, Float]: + """ + Transforming the total mass M and symmetric mass ratio eta to the primary mass m1 + and secondary mass m2. + + Parameters + ---------- + M : Float + Total mass. + eta : Float + Symmetric mass ratio. + + Returns + ------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + """ + m1 = M_tot * (1 + jnp.sqrt(1 - 4 * eta)) / 2 + m2 = M_tot * (1 - jnp.sqrt(1 - 4 * eta)) / 2 + return m1, m2 + + +def m1_m2_to_Mc_eta(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c + and symmetric mass ratio eta. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M_c : Float + Chirp mass. + eta : Float + Symmetric mass ratio. + """ + M = m1 + m2 + eta = m1 * m2 / M**2 + M_c = M * eta ** (3.0 / 5) + return M_c, eta + + +def Mc_eta_to_m1_m2(M_c: Float, eta: Float) -> tuple[Float, Float]: + """ + Transforming the chirp mass M_c and symmetric mass ratio eta to the primary mass m1 + and secondary mass m2. + + Parameters + ---------- + M_c : Float + Chirp mass. + eta : Float + Symmetric mass ratio. + + Returns + ------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + """ + M = M_c / eta ** (3.0 / 5) + m1 = M * (1 + jnp.sqrt(1 - 4 * eta)) / 2 + m2 = M * (1 - jnp.sqrt(1 - 4 * eta)) / 2 + return m1, m2 + + +def q_to_eta(q: Float) -> Float: + """ + Transforming the chirp mass M_c and mass ratio q to the symmetric mass ratio eta. + + Parameters + ---------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + + Returns + ------- + eta : Float + Symmetric mass ratio. + """ + eta = q / (1 + q) ** 2 + return eta + + +def eta_to_q(eta: Float) -> Float: + """ + Transforming the symmetric mass ratio eta to the mass ratio q. + + Copied and modified from bilby/gw/conversion.py + + Parameters + ---------- + eta : Float + Symmetric mass ratio. + + Returns + ------- + q : Float + Mass ratio. + """ + temp = 1 / eta / 2 - 1 + return temp - (temp**2 - 1) ** 0.5 def euler_rotation(delta_x: Float[Array, " 3"]): @@ -149,11 +286,10 @@ def euler_rotation(delta_x: Float[Array, " 3"]): Copied and modified from bilby-cython/geometry.pyx """ - norm = jnp.power( - delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5 - ) + norm = jnp.linalg.vector_norm(delta_x) + cos_beta = delta_x[2] / norm - sin_beta = jnp.power(1 - cos_beta**2, 0.5) + sin_beta = jnp.sqrt(1 - cos_beta**2) alpha = jnp.atan2(-delta_x[1] * cos_beta, delta_x[0]) gamma = jnp.atan2(delta_x[1], delta_x[0]) @@ -182,8 +318,8 @@ def euler_rotation(delta_x: Float[Array, " 3"]): return rotation -def zenith_azimuth_to_theta_phi( - zenith: Float, azimuth: Float, delta_x: Float[Array, " 3"] +def angle_rotation( + zenith: Float, azimuth: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame. @@ -196,8 +332,8 @@ def zenith_azimuth_to_theta_phi( Zenith angle. azimuth : Float Azimuthal angle. - delta_x : Float - The vector pointing from the first detector to the second detector. + rotation : Float[Array, " 3 3"] + The rotation matrix. Returns ------- @@ -211,8 +347,6 @@ def zenith_azimuth_to_theta_phi( sin_zenith = jnp.sin(zenith) cos_zenith = jnp.cos(zenith) - rotation = euler_rotation(delta_x) - theta = jnp.acos( rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth @@ -228,7 +362,7 @@ def zenith_azimuth_to_theta_phi( + rotation[0][2] * cos_zenith, ) + 2 * jnp.pi, - (2 * jnp.pi), + 2 * jnp.pi, ) return theta, phi @@ -255,11 +389,170 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F """ ra = phi + gmst dec = jnp.pi / 2 - theta + ra = ra % (2 * jnp.pi) return ra, dec +def spin_to_cartesian_spin( + thetaJN: Float, + phiJL: Float, + theta1: Float, + theta2: Float, + phi12: Float, + chi1: Float, + chi2: Float, + M_c: Float, + q: Float, + fRef: Float, + phiRef: Float, +) -> tuple[Float, Float, Float, Float, Float, Float, Float]: + """ + Transforming the spin parameters + + The code is based on the approach used in LALsimulation: + https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html + + Parameters: + ------- + thetaJN: Float + Zenith angle between the total angular momentum and the line of sight + phiJL: Float + Difference between total and orbital angular momentum azimuthal angles + theta1: Float + Zenith angle between the spin and orbital angular momenta for the primary object + theta2: Float + Zenith angle between the spin and orbital angular momenta for the secondary object + phi12: Float + Difference between the azimuthal angles of the individual spin vector projections + onto the orbital plane + chi1: Float + Primary object aligned spin: + chi2: Float + Secondary object aligned spin: + M_c: Float + The chirp mass + eta: Float + The symmetric mass ratio + fRef: Float + The reference frequency + phiRef: Float + Binary phase at a reference frequency + + Returns: + ------- + iota: Float + Zenith angle between the orbital angular momentum and the line of sight + S1x: Float + The x-component of the primary spin + S1y: Float + The y-component of the primary spin + S1z: Float + The z-component of the primary spin + S2x: Float + The x-component of the secondary spin + S2y: Float + The y-component of the secondary spin + S2z: Float + The z-component of the secondary spin + """ + + def rotate_y(angle, vec): + """ + Rotate the vector (x, y, z) about y-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + def rotate_z(angle, vec): + """ + Rotate the vector (x, y, z) about z-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + LNh = jnp.array([0.0, 0.0, 1.0]) + + s1hat = jnp.array( + [ + jnp.sin(theta1) * jnp.cos(phiRef), + jnp.sin(theta1) * jnp.sin(phiRef), + jnp.cos(theta1), + ] + ) + s2hat = jnp.array( + [ + jnp.sin(theta2) * jnp.cos(phi12 + phiRef), + jnp.sin(theta2) * jnp.sin(phi12 + phiRef), + jnp.cos(theta2), + ] + ) + + m1, m2 = Mc_q_to_m1_m2(M_c, q) + eta = q / (1 + q) ** 2 + v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef) + + Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0)) + s1 = m1 * m1 * chi1 * s1hat + s2 = m2 * m2 * chi2 * s2hat + J = s1 + s2 + jnp.array([0.0, 0.0, Lmag]) + + Jhat = J / jnp.linalg.norm(J) + theta0 = jnp.arccos(Jhat[2]) + phi0 = jnp.arctan2(Jhat[1], Jhat[0]) + + # Rotation 1: + s1hat = rotate_z(-phi0, s1hat) + s2hat = rotate_z(-phi0, s2hat) + + # Rotation 2: + LNh = rotate_y(-theta0, LNh) + s1hat = rotate_y(-theta0, s1hat) + s2hat = rotate_y(-theta0, s2hat) + + # Rotation 3: + LNh = rotate_z(phiJL - jnp.pi, LNh) + s1hat = rotate_z(phiJL - jnp.pi, s1hat) + s2hat = rotate_z(phiJL - jnp.pi, s2hat) + + # Compute iota + N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) + iota = jnp.arccos(jnp.dot(N, LNh)) + + thetaLJ = jnp.arccos(LNh[2]) + phiL = jnp.arctan2(LNh[1], LNh[0]) + + # Rotation 4: + s1hat = rotate_z(-phiL, s1hat) + s2hat = rotate_z(-phiL, s2hat) + N = rotate_z(-phiL, N) + + # Rotation 5: + s1hat = rotate_y(-thetaLJ, s1hat) + s2hat = rotate_y(-thetaLJ, s2hat) + N = rotate_y(-thetaLJ, N) + + # Rotation 6: + phiN = jnp.arctan2(N[1], N[0]) + s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat) + s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat) + + S1 = s1hat * chi1 + S2 = s2hat * chi2 + return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] + + def zenith_azimuth_to_ra_dec( - zenith: Float, azimuth: Float, gmst: Float, delta_x: Float[Array, " 3"] + zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. @@ -272,8 +565,8 @@ def zenith_azimuth_to_ra_dec( Azimuthal angle. gmst : Float Greenwich mean sidereal time. - delta_x : Float - The vector pointing from the first detector to the second detector. + rotation : Float[Array, " 3 3"] + The rotation matrix. Copied and modified from bilby/gw/utils.py @@ -284,26 +577,62 @@ def zenith_azimuth_to_ra_dec( dec : Float Declination. """ - theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + theta, phi = angle_rotation(zenith, azimuth, rotation) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) - ra = ra % (2 * jnp.pi) return ra, dec -def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]: +def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: """ - A numerically stable method to evaluate log of - a modified Bessel function of order 0. - It is used in the phase-marginalized likelihood. + Transforming the right ascension ra and declination dec to the polar angle + theta and azimuthal angle phi. Parameters - ========== - x: array-like - Value(s) at which to evaluate the function + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. Returns - ======= - array-like: - The natural logarithm of the bessel function + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ + phi = ra - gmst + theta = jnp.pi / 2 - dec + phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) + return theta, phi + + +def ra_dec_to_zenith_azimuth( + ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] +) -> tuple[Float, Float]: + """ + Transforming the right ascension and declination to the zenith angle and azimuthal angle. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + rotation : Float[Array, " 3 3"] + The rotation matrix. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. """ - return jnp.log(i0e(x)) + x + theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) + zenith, azimuth = angle_rotation(theta, phi, rotation) + return zenith, azimuth diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4bcb7287..5a28466d 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -105,7 +105,7 @@ class BijectiveTransform(NtoNTransform): inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] - def inverse(self, y: dict[str, Float]) -> dict[str, Float]: + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Inverse transform the input y to original coordinate x. @@ -118,6 +118,8 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: ------- x : dict[str, Float] The original dictionary. + log_det : Float + The log Jacobian determinant. """ y_copy = y.copy() transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) @@ -135,7 +137,7 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: ) return y_copy, jacobian - def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + def backward(self, y: dict[str, Float]) -> dict[str, Float]: """ Pull back the input y to original coordinate x and return the log Jacobian determinant. @@ -148,8 +150,6 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: ------- x : dict[str, Float] The original dictionary. - log_det : Float - The log Jacobian determinant. """ y_copy = y.copy() output_params = self.inverse_transform_func(y_copy) @@ -164,6 +164,7 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: return y_copy +@jaxtyped(typechecker=typechecker) class ScaleTransform(BijectiveTransform): scale: Float @@ -184,6 +185,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class OffsetTransform(BijectiveTransform): offset: Float @@ -204,6 +206,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class LogitTransform(BijectiveTransform): """ Logit transform following @@ -232,6 +235,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class ArcSineTransform(BijectiveTransform): """ ArcSine transformation @@ -299,6 +303,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -318,24 +323,27 @@ def logit(x): return jnp.log(x / (1 - x)) super().__init__(name_mapping) - self.original_lower_bound = original_lower_bound - self.original_upper_bound = original_upper_bound + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.original_upper_bound = jnp.atleast_1d(original_upper_bound) self.transform_func = lambda x: { name_mapping[1][i]: logit( - (x[name_mapping[0][i]] - self.original_lower_bound) - / (self.original_upper_bound - self.original_lower_bound) + (x[name_mapping[0][i]] - self.original_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) ) for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound) + name_mapping[0][i]: ( + self.original_upper_bound[i] - self.original_lower_bound[i] + ) / (1 + jnp.exp(-x[name_mapping[1][i]])) + self.original_lower_bound[i] for i in range(len(name_mapping[1])) } +@jaxtyped(typechecker=typechecker) class SingleSidedUnboundTransform(BijectiveTransform): """ Unbound upper limit transformation diff --git a/src/jimgw/utils.py b/src/jimgw/utils.py new file mode 100644 index 00000000..70c6e166 --- /dev/null +++ b/src/jimgw/utils.py @@ -0,0 +1,37 @@ +import jax.numpy as jnp +from jax.scipy.special import i0e +from jaxtyping import Array, Float + +from jimgw.prior import Prior + + +def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output + + +def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]: + """ + A numerically stable method to evaluate log of + a modified Bessel function of order 0. + It is used in the phase-marginalized likelihood. + + Parameters + ========== + x: array-like + Value(s) at which to evaluate the function + + Returns + ======= + array-like: + The natural logarithm of the bessel function + """ + return jnp.log(i0e(x)) + x diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index deb3fb98..6fddf9ea 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,6 +8,9 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -27,32 +30,27 @@ fmin = 20.0 fmax = 1024.0 -ifos = ["H1", "L1"] +ifos = [H1, L1] -H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -eta_prior = UniformPrior( - 0.125, - 0.25, - parameter_names=["eta"], # Need name transformation in likelihood to work -) +q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) -# Current likelihood sampling will fail and give nan because of large number dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -iota_prior = CosinePrior(parameter_names=["iota"]) +iota_prior = SinePrior(parameter_names=["iota"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) -dec_prior = SinePrior(parameter_names=["dec"]) +dec_prior = CosinePrior(parameter_names=["dec"]) prior = CombinePrior( [ Mc_prior, - eta_prior, + q_prior, s1z_prior, s2z_prior, dL_prior, @@ -64,8 +62,28 @@ dec_prior, ] ) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + likelihood = TransientLikelihoodFD( - [H1, L1], + ifos, waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, @@ -88,6 +106,8 @@ jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=1, n_local_steps=5, diff --git a/test/integration/test_GW150914_PV2.py b/test/integration/test_GW150914_PV2.py new file mode 100644 index 00000000..6be02936 --- /dev/null +++ b/test/integration/test_GW150914_PV2.py @@ -0,0 +1,141 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior(0.125, 1., parameter_names=["q"]) +theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) +phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) +theta_1_prior = SinePrior(parameter_names=["theta_1"]) +theta_2_prior = SinePrior(parameter_names=["theta_2"]) +phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) +a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) +a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) +dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + Mc_prior, + q_prior, + theta_jn_prior, + phi_jl_prior, + theta_1_prior, + theta_2_prior, + phi_12_prior, + a_1_prior, + a_2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) +] + +likelihood_transforms = [ + SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(15) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index 1cbed508..852ded16 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -1,4 +1,5 @@ from jimgw.prior import * +from jimgw.utils import trace_prior_parent import scipy.stats as stats @@ -30,7 +31,6 @@ def test_uniform(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is correct in the support - samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin))) @@ -40,7 +40,6 @@ def test_sine(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite - samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support @@ -48,7 +47,7 @@ def test_sine(self): y = jax.vmap(p.base_prior.base_prior.transform)(x) y = jax.vmap(p.base_prior.transform)(y) y = jax.vmap(p.transform)(y) - assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.sin(y['x'])/2.0)) + assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.sin(y['x'])/2.0)) def test_cosine(self): p = CosinePrior(["x"]) @@ -56,14 +55,13 @@ def test_cosine(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite - samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) y = jax.vmap(p.base_prior.transform)(x) y = jax.vmap(p.transform)(y) - assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.cos(y['x'])/2.0)) + assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.cos(y['x'])/2.0)) def test_uniform_sphere(self): p = UniformSpherePrior(["x"]) @@ -73,12 +71,10 @@ def test_uniform_sphere(self): assert jnp.all(jnp.isfinite(samples['x_theta'])) assert jnp.all(jnp.isfinite(samples['x_phi'])) # Check that the log_prob is finite - samples = {} - for i in range(3): - samples.update(trace_prior_parent(p, [])[i].sample(jax.random.PRNGKey(0), 10000)) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) + def test_power_law(self): def powerlaw_log_pdf(x, alpha, xmin, xmax): if alpha == -1.0: @@ -96,14 +92,12 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = p.sample(jax.random.PRNGKey(0), 10000) - log_p = jax.vmap(p.log_prob, [0])(samples) + log_p = jax.vmap(p.log_prob, [0])(powerlaw_samples) assert jnp.all(jnp.isfinite(log_p)) # Check that the log_prob is correct in the support - samples = p.sample(jax.random.PRNGKey(0), 10000) - log_prob = jax.vmap(p.log_prob)(samples) - standard_log_prob = powerlaw_log_pdf(samples['x'], alpha, xmin, xmax) + log_prob = jax.vmap(p.log_prob)(powerlaw_samples) + standard_log_prob = powerlaw_log_pdf(powerlaw_samples['x'], alpha, xmin, xmax) # log pdf of powerlaw assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4) @@ -116,4 +110,4 @@ def func(alpha): func(alpha_val) negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] for alpha_val in negative_alpha: - func(alpha_val) + func(alpha_val) \ No newline at end of file