diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 4806efed..d70c2ae1 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -3,11 +3,7 @@ name: Python package -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] +on: [push, pull_request] jobs: build: diff --git a/docs/tutorials/naming_system.md b/docs/tutorials/naming_system.md new file mode 100644 index 00000000..e69de29b diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 0ab1105a..fae0bc98 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -8,17 +8,51 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior +from jimgw.transforms import BijectiveTransform, NtoMTransform class Jim(object): """ Master class for interfacing with flowMC - """ - def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): - self.Likelihood = likelihood - self.Prior = prior + likelihood: LikelihoodBase + prior: Prior + + # Name of parameters to sample from + sample_transforms: list[BijectiveTransform] + likelihood_transforms: list[NtoMTransform] + parameter_names: list[str] + sampler: Sampler + + def __init__( + self, + likelihood: LikelihoodBase, + prior: Prior, + sample_transforms: list[BijectiveTransform] = [], + likelihood_transforms: list[NtoMTransform] = [], + **kwargs, + ): + self.likelihood = likelihood + self.prior = prior + + self.sample_transforms = sample_transforms + self.likelihood_transforms = likelihood_transforms + self.parameter_names = prior.parameter_names + + if len(sample_transforms) == 0: + print( + "No sample transforms provided. Using prior parameters as sampling parameters" + ) + else: + print("Using sample transforms") + for transform in sample_transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) + + if len(likelihood_transforms) == 0: + print( + "No likelihood transforms provided. Using prior parameters as likelihood parameters" + ) seed = kwargs.get("seed", 0) @@ -33,11 +67,11 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): rng_key, subkey = jax.random.split(rng_key) model = MaskedCouplingRQSpline( - self.Prior.n_dim, num_layers, hidden_size, num_bins, subkey + self.prior.n_dim, num_layers, hidden_size, num_bins, subkey ) - self.Sampler = Sampler( - self.Prior.n_dim, + self.sampler = Sampler( + self.prior.n_dim, rng_key, None, # type: ignore local_sampler, @@ -45,18 +79,44 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): **kwargs, ) - def posterior(self, params: Float[Array, " n_dim"], data: dict): - prior_params = self.Prior.add_name(params.T) - prior = self.Prior.log_prob(prior_params) - return ( - self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior - ) + def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: + """ + Turn an array into a dictionary - def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): - if initial_guess.size == 0: - initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T - self.Sampler.sample(initial_guess, None) # type: ignore + Parameters + ---------- + x : Array + An array of parameters. Shape (n_dim,). + """ + + return dict(zip(self.parameter_names, x)) + + def posterior(self, params: Float[Array, " n_dim"], data: dict): + named_params = self.add_name(params) + transform_jacobian = 0.0 + for transform in reversed(self.sample_transforms): + named_params, jacobian = transform.inverse(named_params) + transform_jacobian += jacobian + prior = self.prior.log_prob(named_params) + transform_jacobian + for transform in self.likelihood_transforms: + named_params = transform.forward(named_params) + return self.likelihood.evaluate(named_params, data) + prior + + def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): + if initial_position.size == 0: + initial_guess = [] + for _ in range(self.sampler.n_chains): + flag = True + while flag: + key = jax.random.split(key)[1] + guess = self.prior.sample(key, 1) + for transform in self.sample_transforms: + guess = transform.forward(guess) + guess = jnp.array([i for i in guess.values()]).T[0] + flag = not jnp.all(jnp.isfinite(guess)) + initial_guess.append(guess) + initial_position = jnp.array(initial_guess) + self.sampler.sample(initial_position, None) # type: ignore def maximize_likelihood( self, @@ -67,7 +127,7 @@ def maximize_likelihood( ): key = jax.random.PRNGKey(seed) set_nwalkers = set_nwalkers - initial_guess = self.Prior.sample(key, set_nwalkers) + initial_guess = self.prior.sample(key, set_nwalkers) def negative_posterior(x: Float[Array, " n_dim"]): return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet @@ -78,7 +138,7 @@ def negative_posterior(x: Float[Array, " n_dim"]): print("Done compiling") print("Starting the optimizer") - optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True) + optimizer = EvolutionaryOptimizer(self.prior.n_dim, verbose=True) _ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return best_fit @@ -89,22 +149,24 @@ def print_summary(self, transform: bool = True): """ - train_summary = self.Sampler.get_sampler_state(training=True) - production_summary = self.Sampler.get_sampler_state(training=False) + train_summary = self.sampler.get_sampler_state(training=True) + production_summary = self.sampler.get_sampler_state(training=False) - training_chain = train_summary["chains"].reshape(-1, self.Prior.n_dim).T - training_chain = self.Prior.add_name(training_chain) + training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T + training_chain = self.add_name(training_chain) if transform: - training_chain = self.Prior.transform(training_chain) + for sample_transform in reversed(self.sample_transforms): + training_chain = sample_transform.backward(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] training_loss = train_summary["loss_vals"] - production_chain = production_summary["chains"].reshape(-1, self.Prior.n_dim).T - production_chain = self.Prior.add_name(production_chain) + production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T + production_chain = self.add_name(production_chain) if transform: - production_chain = self.Prior.transform(production_chain) + for sample_transform in reversed(self.sample_transforms): + production_chain = sample_transform.backward(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -156,11 +218,14 @@ def get_samples(self, training: bool = False) -> dict: """ if training: - chains = self.Sampler.get_sampler_state(training=True)["chains"] + chains = self.sampler.get_sampler_state(training=True)["chains"] else: - chains = self.Sampler.get_sampler_state(training=False)["chains"] + chains = self.sampler.get_sampler_state(training=False)["chains"] - chains = self.Prior.transform(self.Prior.add_name(chains.transpose(2, 0, 1))) + chains = chains.transpose(2, 0, 1) + chains = self.add_name(chains) + for sample_transform in reversed(self.sample_transforms): + chains = sample_transform.backward(chains) return chains def plot(self): diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index f1827753..5227ffa5 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -1,14 +1,20 @@ from dataclasses import field -from typing import Callable, Union import jax import jax.numpy as jnp -from flowMC.nfmodel.base import Distribution -from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped from beartype import beartype as typechecker -from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.single_event.detector import GroundBased2G, detector_preset -from astropy.time import Time +from flowMC.nfmodel.base import Distribution +from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped + +from jimgw.transforms import ( + BijectiveTransform, + LogitTransform, + ScaleTransform, + OffsetTransform, + ArcSineTransform, + PowerLawTransform, + ParetoTransform, +) class Prior(Distribution): @@ -21,58 +27,21 @@ class Prior(Distribution): the names of the parameters and the transforms that are applied to them. """ - naming: list[str] - transforms: dict[str, tuple[str, Callable]] = field(default_factory=dict) + parameter_names: list[str] + composite: bool = False @property - def n_dim(self): - return len(self.naming) + def n_dim(self) -> int: + return len(self.parameter_names) - def __init__( - self, naming: list[str], transforms: dict[str, tuple[str, Callable]] = {} - ): + def __init__(self, parameter_names: list[str]): """ Parameters ---------- - naming : list[str] + parameter_names : list[str] A list of names for the parameters of the prior. - transforms : dict[tuple[str,Callable]] - A dictionary of transforms to apply to the parameters. The keys are - the names of the parameters and the values are a tuple of the name - of the transform and the transform itself. """ - self.naming = naming - self.transforms = {} - - def make_lambda(name): - return lambda x: x[name] - - for name in naming: - if name in transforms: - self.transforms[name] = transforms[name] - else: - # Without the function, the lambda will refer to the variable name instead of its value, - # which will make lambda reference the last value of the variable name - self.transforms[name] = (name, make_lambda(name)) - - def transform(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Apply the transforms to the parameters. - - Parameters - ---------- - x : dict - A dictionary of parameters. Names should match the ones in the prior. - - Returns - ------- - x : dict - A dictionary of parameters with the transforms applied. - """ - output = {} - for value in self.transforms.values(): - output[value[0]] = value[1](x) - return output + self.parameter_names = parameter_names def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ @@ -84,43 +53,33 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: An array of parameters. Shape (n_dim,). """ - return dict(zip(self.naming, x)) + return dict(zip(self.parameter_names, x)) def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: raise NotImplementedError - def log_prob(self, x: dict[str, Array]) -> Float: + def log_prob(self, z: dict[str, Array]) -> Float: raise NotImplementedError @jaxtyped(typechecker=typechecker) -class Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 +class LogisticDistribution(Prior): def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax})" + return f"LogisticDistribution(parameter_names={self.parameter_names})" - def __init__( - self, - xmin: Float, - xmax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Uniform needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin + def __init__(self, parameter_names: list[str], **kwargs): + super().__init__(parameter_names) + self.composite = False + assert self.n_dim == 1, "LogisticDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: """ - Sample from a uniform distribution. + Sample from a logistic distribution. Parameters ---------- @@ -135,72 +94,33 @@ def sample( Samples from the distribution. The keys are the names of the parameters. """ - samples = jax.random.uniform( - rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax - ) + samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) + samples = jnp.log(samples / (1 - samples)) return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - return output + jnp.log(1.0 / (self.xmax - self.xmin)) + def log_prob(self, z: dict[str, Float]) -> Float: + variable = z[self.parameter_names[0]] + return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) @jaxtyped(typechecker=typechecker) -class Unconstrained_Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 +class StandardNormalDistribution(Prior): def __repr__(self): - return f"Unconstrained_Uniform(xmin={self.xmin}, xmax={self.xmax})" + return f"StandardNormalDistribution(parameter_names={self.parameter_names})" - def __init__( - self, - xmin: Float, - xmax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - local_transform = self.transforms - - def new_transform(param): - param[self.naming[0]] = self.to_range(param[self.naming[0]]) - return local_transform[self.naming[0]][1](param) - - self.transforms = { - self.naming[0]: (local_transform[self.naming[0]][0], new_transform) - } - - def to_range(self, x: Float) -> Float: - """ - Transform the parameters to the range of the prior. - - Parameters - ---------- - x : Float - The parameters to transform. - - Returns - ------- - x : dict - A dictionary of parameters with the transforms applied. - """ - return (self.xmax - self.xmin) / (1 + jnp.exp(-x)) + self.xmin + def __init__(self, parameter_names: list[str], **kwargs): + super().__init__(parameter_names) + self.composite = False + assert ( + self.n_dim == 1 + ), "StandardNormalDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: """ - Sample from a uniform distribution. + Sample from a standard normal distribution. Parameters ---------- @@ -211,513 +131,350 @@ def sample( Returns ------- - samples : - An array of shape (n_samples, n_dim) containing the samples. + samples : dict + Samples from the distribution. The keys are the names of the parameters. """ - samples = jax.random.uniform(rng_key, (n_samples,), minval=0, maxval=1) - samples = jnp.log(samples / (1 - samples)) + samples = jax.random.normal(rng_key, (n_samples,)) return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) + def log_prob(self, z: dict[str, Float]) -> Float: + variable = z[self.parameter_names[0]] + return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) -class Sphere(Prior): +class SequentialTransformPrior(Prior): """ - A prior on a sphere represented by Cartesian coordinates. - - Magnitude is sampled from a uniform distribution. + Transform a prior distribution by applying a sequence of transforms. + The space before the transform is named as x, + and the space after the transform is named as z """ - def __repr__(self): - return f"Sphere(naming={self.naming})" - - def __init__(self, naming: list[str], **kwargs): - name = naming[0] - self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] - self.transforms = { - self.naming[0]: ( - f"{naming}_x", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.cos(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[1]: ( - f"{naming}_y", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.sin(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[2]: ( - f"{naming}_z", - lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 3) - theta = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) - mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) - return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - theta = x[self.naming[0]] - phi = x[self.naming[1]] - mag = x[self.naming[2]] - output = jnp.where( - (mag > 1) - | (mag < 0) - | (phi > 2 * jnp.pi) - | (phi < 0) - | (theta > jnp.pi) - | (theta < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), - ) - return output - - -@jaxtyped(typechecker=typechecker) -class AlignedSpin(Prior): - """ - Prior distribution for the aligned (z) component of the spin. - - This assume the prior distribution on the spin magnitude to be uniform in [0, amax] - with its orientation uniform on a sphere - - p(chi) = -log(|chi| / amax) / 2 / amax - - This is useful when comparing results between an aligned-spin run and - a precessing spin run. - - See (A7) of https://arxiv.org/abs/1805.10457. - """ - - amax: Float = 0.99 - chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + base_prior: Prior + transforms: list[BijectiveTransform] def __repr__(self): - return f"Alignedspin(amax={self.amax}, naming={self.naming})" + return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" def __init__( self, - amax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, + base_prior: Prior, + transforms: list[BijectiveTransform], ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" - self.amax = amax - # build the interpolation table for the ppf of the one-sided distribution - chi_axis = jnp.linspace(1e-31, self.amax, num=1000) - cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax - self.chi_axis = chi_axis - self.cdf_vals = cdf_vals - - @property - def xmin(self): - return -self.amax - - @property - def xmax(self): - return self.amax + self.base_prior = base_prior + self.transforms = transforms + self.parameter_names = base_prior.parameter_names + for transform in transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) + self.composite = True def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from the Alignedspin distribution. - - for chi > 0; - p(chi) = -log(chi / amax) / amax # halved normalization constant - cdf(chi) = -chi * (log(chi / amax) - 1) / amax - - Since there is a pole at chi=0, we will sample with the following steps - 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise - 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) - 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) - 3. Map the quantile to chi via the ppf by checking against the table - built during the initialization - 4. add back the sign - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. + output = self.base_prior.sample(rng_key, n_samples) + return jax.vmap(self.transform)(output) + def log_prob(self, z: dict[str, Float]) -> Float: """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - # 1. calculate the sign of chi from the q_samples - sign_samples = jnp.where( - q_samples >= 0.5, - jnp.zeros_like(q_samples) + 1.0, - jnp.zeros_like(q_samples) - 1.0, - ) - # 2. remap q_samples - q_samples = jnp.where( - q_samples >= 0.5, - 2 * (q_samples - 0.5), - 2 * (0.5 - q_samples), - ) - # 3. map the quantile to chi via interpolation - samples = jnp.interp( - q_samples, - self.cdf_vals, - self.chi_axis, - ) - # 4. add back the sign - samples *= sign_samples - - return self.add_name(samples[None]) + Evaluating the probability of the transformed variable z. + This is what flowMC should sample from + """ + output = 0 + for transform in reversed(self.transforms): + z, log_jacobian = transform.inverse(z) + output += log_jacobian + output += self.base_prior.log_prob(z) + return output - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_p = jnp.where( - (variable >= self.amax) | (variable <= -self.amax), - jnp.zeros_like(variable) - jnp.inf, - jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), - ) - return log_p + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + for transform in self.transforms: + x = transform.forward(x) + return x -@jaxtyped(typechecker=typechecker) -class EarthFrame(Prior): +class CombinePrior(Prior): """ - Prior distribution for sky location in Earth frame. + A prior class constructed by joinning multiple priors together to form a multivariate prior. + This assumes the priors composing the Combine class are independent. """ - ifos: list = field(default_factory=list) - gmst: float = 0.0 - delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) + base_prior: list[Prior] = field(default_factory=list) def __repr__(self): - return f"EarthFrame(naming={self.naming})" - - def __init__(self, gps: Float, ifos: list, **kwargs): - self.naming = ["zenith", "azimuth"] - if len(ifos) < 2: - return ValueError( - "At least two detectors are needed to define the Earth frame" - ) - elif isinstance(ifos[0], str): - self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] - elif isinstance(ifos[0], GroundBased2G): - self.ifos = ifos[:1] - else: - return ValueError( - "ifos should be a list of detector names or GroundBased2G objects" - ) - self.gmst = float( - Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad + return ( + f"Combine(priors={self.base_prior}, parameter_names={self.parameter_names})" ) - self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex - - self.transforms = { - "azimuth": ( - "ra", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[0], - ), - "zenith": ( - "dec", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[1], - ), - } + + def __init__( + self, + priors: list[Prior], + ): + parameter_names = [] + for prior in priors: + parameter_names += prior.parameter_names + self.base_prior = priors + self.parameter_names = parameter_names + self.composite = True def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 2) - zenith = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - azimuth = jax.random.uniform( - rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi - ) - return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - zenith = x["zenith"] - azimuth = x["azimuth"] - output = jnp.where( - (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.zeros_like(0), - ) - return output + jnp.log(jnp.sin(zenith)) + output = {} + for prior in self.base_prior: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + return output + def log_prob(self, z: dict[str, Float]) -> Float: + output = 0.0 + for prior in self.base_prior: + output += prior.log_prob(z) + return output -@jaxtyped(typechecker=typechecker) -class PowerLaw(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ x^{\alpha} - """ - xmin: float = 0.0 - xmax: float = 1.0 - alpha: float = 0.0 - normalization: float = 1.0 +@jaxtyped(typechecker=typechecker) +class UniformPrior(SequentialTransformPrior): + xmin: float + xmax: float def __repr__(self): - return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + return f"UniformPrior(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" def __init__( self, xmin: float, xmax: float, - alpha: Union[Int, float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, + parameter_names: list[str], ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin > 0.0, "With negative alpha, xmin must > 0" - assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" + self.parameter_names = parameter_names + assert self.n_dim == 1, "UniformPrior needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - self.alpha = alpha - if alpha == -1: - self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) - else: - self.normalization = (1 + self.alpha) / ( - self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a power-law distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - if self.alpha == -1: - samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) - else: - samples = ( - self.xmin ** (1.0 + self.alpha) - + q_samples - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), + super().__init__( + LogisticDistribution([f"{self.parameter_names[0]}_base"]), + [ + LogitTransform( + ( + [f"{self.parameter_names[0]}_base"], + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + ) + ), + ScaleTransform( + ( + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + [f"{self.parameter_names[0]}-({xmin})"], + ), + xmax - xmin, + ), + OffsetTransform( + ([f"{self.parameter_names[0]}-({xmin})"], self.parameter_names), + xmin, + ), + ], ) - log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) - return log_p + log_in_range @jaxtyped(typechecker=typechecker) -class Exponential(Prior): +class SinePrior(SequentialTransformPrior): """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ exp(\alpha x) + A prior distribution where the pdf is proportional to sin(x) in the range [0, pi]. """ - xmin: float = 0.0 - xmax: float = jnp.inf - alpha: float = -1.0 - normalization: float = 1.0 - def __repr__(self): - return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: Float, - xmax: Float, - alpha: Union[Int, Float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin != -jnp.inf, "With negative alpha, xmin must finite" - if alpha > 0.0: - assert xmax != jnp.inf, "With positive alpha, xmax must finite" - assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" - assert self.n_dim == 1, "Exponential needs to be 1D distributions" - - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - - self.normalization = self.alpha / ( - jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) + return f"SinePrior(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 1, "SinePrior needs to be 1D distributions" + super().__init__( + CosinePrior([f"{self.parameter_names[0]}-pi/2"]), + [ + OffsetTransform( + ( + ( + [f"{self.parameter_names[0]}-pi/2"], + [f"{self.parameter_names[0]}"], + ) + ), + jnp.pi / 2, + ) + ], ) - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a exponential distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - samples = ( - self.xmin - + jnp.log1p( - q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) - ) - / self.alpha - ) - return self.add_name(samples[None]) +@jaxtyped(typechecker=typechecker) +class CosinePrior(SequentialTransformPrior): + """ + A prior distribution where the pdf is proportional to cos(x) in the range [-pi/2, pi/2]. + """ - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), + def __repr__(self): + return f"CosinePrior(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" + super().__init__( + UniformPrior(-1.0, 1.0, [f"sin({self.parameter_names[0]})"]), + [ + ArcSineTransform( + ([f"sin({self.parameter_names[0]})"], [self.parameter_names[0]]) + ) + ], ) - log_p = self.alpha * variable + jnp.log(self.normalization) - return log_p + log_in_range @jaxtyped(typechecker=typechecker) -class Normal(Prior): - mean: Float = 0.0 - std: Float = 1.0 +class UniformSpherePrior(CombinePrior): def __repr__(self): - return f"Normal(mean={self.mean}, std={self.std})" - - def __init__( - self, - mean: Float, - std: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Normal needs to be 1D distributions" - self.mean = mean - self.std = std - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a normal distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - samples = jax.random.normal(rng_key, (n_samples,)) - samples = self.mean + samples * self.std - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = ( - -0.5 * jnp.log(2 * jnp.pi) - - jnp.log(self.std) - - 0.5 * ((variable - self.mean) / self.std) ** 2 + return f"UniformSpherePrior(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str], **kwargs): + self.parameter_names = parameter_names + assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector" + self.parameter_names = [ + f"{self.parameter_names[0]}_mag", + f"{self.parameter_names[0]}_theta", + f"{self.parameter_names[0]}_phi", + ] + super().__init__( + [ + UniformPrior(0.0, 1.0, [self.parameter_names[0]]), + SinePrior([self.parameter_names[1]]), + UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + ] ) - return output -class Composite(Prior): - priors: list[Prior] = field(default_factory=list) +@jaxtyped(typechecker=typechecker) +class PowerLawPrior(SequentialTransformPrior): + xmin: float + xmax: float + alpha: float def __repr__(self): - return f"Composite(priors={self.priors}, naming={self.naming})" + return f"PowerLawPrior(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" def __init__( self, - priors: list[Prior], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, + xmin: float, + xmax: float, + alpha: float, + parameter_names: list[str], ): - naming = [] - self.transforms = {} - for prior in priors: - naming += prior.naming - self.transforms.update(prior.transforms) - self.priors = priors - self.naming = naming - self.transforms.update(transforms) + self.parameter_names = parameter_names + assert self.n_dim == 1, "Power law needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + assert self.xmin < self.xmax, "xmin must be less than xmax" + assert self.xmin > 0.0, "x must be positive" + if self.alpha == -1.0: + transform = ParetoTransform( + ([f"{self.parameter_names[0]}_before_transform"], self.parameter_names), + xmin, + xmax, + ) + else: + transform = PowerLawTransform( + ([f"{self.parameter_names[0]}_before_transform"], self.parameter_names), + xmin, + xmax, + alpha, + ) + super().__init__( + LogisticDistribution([f"{self.parameter_names[0]}_base"]), + [ + LogitTransform( + ( + [f"{self.parameter_names[0]}_base"], + [f"{self.parameter_names[0]}_before_transform"], + ) + ), + transform, + ], + ) - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - output = {} - for prior in self.priors: - rng_key, subkey = jax.random.split(rng_key) - output.update(prior.sample(subkey, n_samples)) - return output - def log_prob(self, x: dict[str, Float]) -> Float: - output = 0.0 - for prior in self.priors: - output += prior.log_prob(x) - return output +# ====================== Things below may need rework ====================== + +# @jaxtyped(typechecker=typechecker) +# class Exponential(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ exp(\alpha x) +# """ + +# xmin: float = 0.0 +# xmax: float = jnp.inf +# alpha: float = -1.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: Float, +# xmax: Float, +# alpha: Union[Int, Float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin != -jnp.inf, "With negative alpha, xmin must finite" +# if alpha > 0.0: +# assert xmax != jnp.inf, "With positive alpha, xmax must finite" +# assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" +# assert self.n_dim == 1, "Exponential needs to be 1D distributions" + +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha + +# self.normalization = self.alpha / ( +# jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a exponential distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# samples = ( +# self.xmin +# + jnp.log1p( +# q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) +# ) +# / self.alpha +# ) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * variable + jnp.log(self.normalization) +# return log_p + log_in_range diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index f10aeed1..00e6ce6b 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -12,8 +12,9 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior from jimgw.single_event.detector import Detector -from jimgw.single_event.utils import log_i0 +from jimgw.utils import log_i0 from jimgw.single_event.waveform import Waveform +from jimgw.transforms import BijectiveTransform, NtoMTransform class SingleEventLiklihood(LikelihoodBase): @@ -184,8 +185,6 @@ def __init__( self, detectors: list[Detector], waveform: Waveform, - prior: Prior, - bounds: Float[Array, " n_dim 2"], n_bins: int = 100, trigger_time: float = 0, duration: float = 4, @@ -194,6 +193,9 @@ def __init__( n_steps: int = 2000, ref_params: dict = {}, reference_waveform: Optional[Waveform] = None, + prior: Optional[Prior] = None, + sample_transforms: list[BijectiveTransform] = [], + likelihood_transforms: list[NtoMTransform] = [], **kwargs, ) -> None: super().__init__( @@ -254,17 +256,24 @@ def __init__( ) self.freq_grid_low = freq_grid[:-1] - if not ref_params: + if ref_params: + self.ref_params = ref_params + print(f"Reference parameters provided, which are {self.ref_params}") + elif prior: print("No reference parameters are provided, finding it...") ref_params = self.maximize_likelihood( - bounds=bounds, prior=prior, popsize=popsize, n_steps=n_steps + prior=prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + popsize=popsize, + n_steps=n_steps, ) self.ref_params = {key: float(value) for key, value in ref_params.items()} print(f"The reference parameters are {self.ref_params}") else: - self.ref_params = ref_params - print(f"Reference parameters provided, which are {self.ref_params}") - + raise ValueError( + "Either reference parameters or parameter names must be provided" + ) # safe guard for the reference parameters # since ripple cannot handle eta=0.25 if jnp.isclose(self.ref_params["eta"], 0.25): @@ -542,25 +551,54 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): def maximize_likelihood( self, prior: Prior, - bounds: Float[Array, " n_dim 2"], + likelihood_transforms: list[NtoMTransform], + sample_transforms: list[BijectiveTransform], popsize: int = 100, n_steps: int = 2000, ): + parameter_names = prior.parameter_names + for transform in sample_transforms: + parameter_names = transform.propagate_name(parameter_names) + def y(x): - return -self.evaluate_original(prior.transform(prior.add_name(x)), {}) + named_params = dict(zip(parameter_names, x)) + for transform in reversed(sample_transforms): + named_params = transform.backward(named_params) + for transform in likelihood_transforms: + named_params = transform.forward(named_params) + return -self.evaluate_original(named_params, {}) print("Starting the optimizer") + optimizer = optimization_Adam( n_steps=n_steps, learning_rate=0.001, noise_level=1 ) - initial_position = jnp.array( - list(prior.sample(jax.random.PRNGKey(0), popsize).values()) - ).T + + key = jax.random.PRNGKey(0) + initial_position = [] + for _ in range(popsize): + flag = True + while flag: + key = jax.random.split(key)[1] + guess = prior.sample(key, 1) + for transform in sample_transforms: + guess = transform.forward(guess) + guess = jnp.array([i for i in guess.values()]).T[0] + flag = not jnp.all(jnp.isfinite(guess)) + initial_position.append(guess) + initial_position = jnp.array(initial_position) rng_key, optimized_positions, summary = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position ) - best_fit = optimized_positions[jnp.nanargmin(summary["final_log_prob"])] - return prior.transform(prior.add_name(best_fit)) + + best_fit = optimized_positions[jnp.argmin(summary["final_log_prob"])] + + named_params = dict(zip(parameter_names, best_fit)) + for transform in reversed(sample_transforms): + named_params = transform.backward(named_params) + for transform in likelihood_transforms: + named_params = transform.forward(named_params) + return named_params likelihood_presets = { diff --git a/src/jimgw/single_event/prior.py b/src/jimgw/single_event/prior.py new file mode 100644 index 00000000..51a754eb --- /dev/null +++ b/src/jimgw/single_event/prior.py @@ -0,0 +1,179 @@ +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import jaxtyped + +from jimgw.prior import ( + Prior, + CombinePrior, + UniformPrior, + PowerLawPrior, + SinePrior, +) + + +@jaxtyped(typechecker=typechecker) +class UniformSpherePrior(CombinePrior): + + def __repr__(self): + return f"UniformSpherePrior(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str], **kwargs): + self.parameter_names = parameter_names + assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector" + self.parameter_names = [ + f"{self.parameter_names[0]}_mag", + f"{self.parameter_names[0]}_theta", + f"{self.parameter_names[0]}_phi", + ] + super().__init__( + [ + UniformPrior(0.0, 1.0, [self.parameter_names[0]]), + SinePrior([self.parameter_names[1]]), + UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + ] + ) + + +@jaxtyped(typechecker=typechecker) +class UniformComponentChirpMassPrior(PowerLawPrior): + """ + A prior in the range [xmin, xmax) for chirp mass which assumes the + component masses to be uniformly distributed. + + p(M_c) ~ M_c + """ + + def __repr__(self): + return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + + def __init__(self, xmin: float, xmax: float): + super().__init__(xmin, xmax, 1.0, ["M_c"]) + + +def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output + + +# ====================== Things below may need rework ====================== + + +# @jaxtyped(typechecker=typechecker) +# class AlignedSpin(Prior): +# """ +# Prior distribution for the aligned (z) component of the spin. + +# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] +# with its orientation uniform on a sphere + +# p(chi) = -log(|chi| / amax) / 2 / amax + +# This is useful when comparing results between an aligned-spin run and +# a precessing spin run. + +# See (A7) of https://arxiv.org/abs/1805.10457. +# """ + +# amax: Float = 0.99 +# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) +# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + +# def __repr__(self): +# return f"Alignedspin(amax={self.amax}, naming={self.naming})" + +# def __init__( +# self, +# amax: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" +# self.amax = amax + +# # build the interpolation table for the ppf of the one-sided distribution +# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) +# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax +# self.chi_axis = chi_axis +# self.cdf_vals = cdf_vals + +# @property +# def xmin(self): +# return -self.amax + +# @property +# def xmax(self): +# return self.amax + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from the Alignedspin distribution. + +# for chi > 0; +# p(chi) = -log(chi / amax) / amax # halved normalization constant +# cdf(chi) = -chi * (log(chi / amax) - 1) / amax + +# Since there is a pole at chi=0, we will sample with the following steps +# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise +# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) +# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) +# 3. Map the quantile to chi via the ppf by checking against the table +# built during the initialization +# 4. add back the sign + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# # 1. calculate the sign of chi from the q_samples +# sign_samples = jnp.where( +# q_samples >= 0.5, +# jnp.zeros_like(q_samples) + 1.0, +# jnp.zeros_like(q_samples) - 1.0, +# ) +# # 2. remap q_samples +# q_samples = jnp.where( +# q_samples >= 0.5, +# 2 * (q_samples - 0.5), +# 2 * (0.5 - q_samples), +# ) +# # 3. map the quantile to chi via interpolation +# samples = jnp.interp( +# q_samples, +# self.cdf_vals, +# self.chi_axis, +# ) +# # 4. add back the sign +# samples *= sign_samples + +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_p = jnp.where( +# (variable >= self.amax) | (variable <= -self.amax), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), +# ) +# return log_p diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py new file mode 100644 index 00000000..cb60e98b --- /dev/null +++ b/src/jimgw/single_event/transforms.py @@ -0,0 +1,155 @@ +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped +from astropy.time import Time + +from jimgw.single_event.detector import GroundBased2G +from jimgw.transforms import ( + BijectiveTransform, + NtoNTransform, + reverse_bijective_transform, +) +from jimgw.single_event.utils import ( + m1_m2_to_Mc_q, + Mc_q_to_m1_m2, + m1_m2_to_Mc_eta, + Mc_eta_to_m1_m2, + q_to_eta, + eta_to_q, + ra_dec_to_zenith_azimuth, + zenith_azimuth_to_ra_dec, + euler_rotation, + spin_to_cartesian_spin, +) + + +@jaxtyped(typechecker=typechecker) +class SpinToCartesianSpinTransform(NtoNTransform): + """ + Spin to Cartesian spin transformation + """ + + freq_ref: Float + + def __init__( + self, + freq_ref: Float, + ): + name_mapping = ( + ["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], + ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"], + ) + super().__init__(name_mapping) + + self.freq_ref = freq_ref + + def named_transform(x): + iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( + x["theta_jn"], + x["phi_jl"], + x["theta_1"], + x["theta_2"], + x["phi_12"], + x["a_1"], + x["a_2"], + x["M_c"], + x["q"], + self.freq_ref, + x["phase_c"], + ) + return { + "iota": iota, + "s1_x": s1x, + "s1_y": s1y, + "s1_z": s1z, + "s2_x": s2x, + "s2_y": s2y, + "s2_z": s2z, + } + + self.transform_func = named_transform + + +@jaxtyped(typechecker=typechecker) +class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): + """ + Transform sky frame to detector frame sky position + """ + + gmst: Float + rotation: Float[Array, " 3 3"] + rotation_inv: Float[Array, " 3 3"] + + def __init__( + self, + gps_time: Float, + ifos: GroundBased2G, + ): + name_mapping = (["ra", "dec"], ["zenith", "azimuth"]) + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + delta_x = ifos[0].vertex - ifos[1].vertex + self.rotation = euler_rotation(delta_x) + self.rotation_inv = jnp.linalg.inv(self.rotation) + + def named_transform(x): + zenith, azimuth = ra_dec_to_zenith_azimuth( + x["ra"], x["dec"], self.gmst, self.rotation + ) + return {"zenith": zenith, "azimuth": azimuth} + + self.transform_func = named_transform + + def named_inverse_transform(x): + ra, dec = zenith_azimuth_to_ra_dec( + x["zenith"], x["azimuth"], self.gmst, self.rotation_inv + ) + return {"ra": ra, "dec": dec} + + self.inverse_transform_func = named_inverse_transform + +def named_m1_m2_to_Mc_q(x): + Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) + return {"M_c": Mc, "q": q} + +def named_Mc_q_to_m1_m2(x): + m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} + +ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"])) +ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q +ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2 + +def named_m1_m2_to_Mc_eta(x): + Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) + return {"M_c": Mc, "eta": eta} + +def named_Mc_eta_to_m1_m2(x): + m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"]) + return {"m_1": m1, "m_2": m2} + +ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"])) +ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta +ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2 + +def named_q_to_eta(x): + return {"eta": q_to_eta(x["q"])} +def named_eta_to_q(x): + return {"q": eta_to_q(x["eta"])} +MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"])) +MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta +MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q + + +ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform( + ComponentMassesToChirpMassMassRatioTransform +) +ChirpMassSymmetricMassRatioToComponentMassesTransform = reverse_bijective_transform( + ComponentMassesToChirpMassSymmetricMassRatioTransform +) +SymmetricMassRatioToMassRatioTransform = reverse_bijective_transform( + MassRatioToSymmetricMassRatioTransform +) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 557a79ea..fb35bf27 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -1,6 +1,5 @@ import jax.numpy as jnp from jax.scipy.integrate import trapezoid -from jax.scipy.special import i0e from jaxtyping import Array, Float from jimgw.constants import Msun @@ -39,7 +38,7 @@ def inner_product( return 4.0 * jnp.real(trapezoid(integrand, dx=df)) -def m1m2_to_Mq(m1: Float, m2: Float): +def m1_m2_to_M_q(m1: Float, m2: Float): """ Transforming the primary mass m1 and secondary mass m2 to the Total mass M and mass ratio q. @@ -58,12 +57,12 @@ def m1m2_to_Mq(m1: Float, m2: Float): q : Float Mass ratio. """ - M_tot = jnp.log(m1 + m2) - q = jnp.log(m2 / m1) - jnp.log(1 - m2 / m1) + M_tot = m1 + m2 + q = m2 / m1 return M_tot, q -def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): +def M_q_to_m1_m2(M_tot: Float, q: Float): """ Transforming the Total mass M and mass ratio q to the primary mass m1 and secondary mass m2. @@ -82,21 +81,45 @@ def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): m2 : Float Secondary mass. """ - M_tot = jnp.exp(trans_M_tot) - q = 1.0 / (1 + jnp.exp(-trans_q)) m1 = M_tot / (1 + q) m2 = m1 * q return m1, m2 -def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: +def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: """ - Transforming the chirp mass Mc and mass ratio q to the primary mass m1 and + Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c + and mass ratio q. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + """ + M_tot = m1 + m2 + eta = m1 * m2 / M_tot**2 + M_c = M_tot * eta ** (3.0 / 5) + q = m2 / m1 + return M_c, q + + +def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]: + """ + Transforming the chirp mass M_c and mass ratio q to the primary mass m1 and secondary mass m2. Parameters ---------- - Mc : Float + M_c : Float Chirp mass. q : Float Mass ratio. @@ -109,36 +132,148 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: Secondary mass. """ eta = q / (1 + q) ** 2 - M_tot = Mc / eta ** (3.0 / 5) + M_tot = M_c / eta ** (3.0 / 5) m1 = M_tot / (1 + q) m2 = m1 * q return m1, m2 -def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: +def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: """ - Transforming the right ascension ra and declination dec to the polar angle - theta and azimuthal angle phi. + Transforming the primary mass m1 and secondary mass m2 to the total mass M + and symmetric mass ratio eta. Parameters ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. + m1 : Float + Primary mass. + m2 : Float + Secondary mass. Returns ------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. + M : Float + Total mass. + eta : Float + Symmetric mass ratio. """ - phi = ra - gmst - theta = jnp.pi / 2 - dec - return theta, phi + M_tot = m1 + m2 + eta = m1 * m2 / M_tot**2 + return M_tot, eta + + +def M_eta_to_m1_m2(M_tot: Float, eta: Float) -> tuple[Float, Float]: + """ + Transforming the total mass M and symmetric mass ratio eta to the primary mass m1 + and secondary mass m2. + + Parameters + ---------- + M : Float + Total mass. + eta : Float + Symmetric mass ratio. + + Returns + ------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + """ + m1 = M_tot * (1 + jnp.sqrt(1 - 4 * eta)) / 2 + m2 = M_tot * (1 - jnp.sqrt(1 - 4 * eta)) / 2 + return m1, m2 + + +def m1_m2_to_Mc_eta(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c + and symmetric mass ratio eta. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M_c : Float + Chirp mass. + eta : Float + Symmetric mass ratio. + """ + M = m1 + m2 + eta = m1 * m2 / M**2 + M_c = M * eta ** (3.0 / 5) + return M_c, eta + + +def Mc_eta_to_m1_m2(M_c: Float, eta: Float) -> tuple[Float, Float]: + """ + Transforming the chirp mass M_c and symmetric mass ratio eta to the primary mass m1 + and secondary mass m2. + + Parameters + ---------- + M_c : Float + Chirp mass. + eta : Float + Symmetric mass ratio. + + Returns + ------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + """ + M = M_c / eta ** (3.0 / 5) + m1 = M * (1 + jnp.sqrt(1 - 4 * eta)) / 2 + m2 = M * (1 - jnp.sqrt(1 - 4 * eta)) / 2 + return m1, m2 + + +def q_to_eta(q: Float) -> Float: + """ + Transforming the chirp mass M_c and mass ratio q to the symmetric mass ratio eta. + + Parameters + ---------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + + Returns + ------- + eta : Float + Symmetric mass ratio. + """ + eta = q / (1 + q) ** 2 + return eta + + +def eta_to_q(eta: Float) -> Float: + """ + Transforming the symmetric mass ratio eta to the mass ratio q. + + Copied and modified from bilby/gw/conversion.py + + Parameters + ---------- + eta : Float + Symmetric mass ratio. + + Returns + ------- + q : Float + Mass ratio. + """ + temp = 1 / eta / 2 - 1 + return temp - (temp**2 - 1) ** 0.5 def spin_to_cartesian_spin( @@ -310,11 +445,10 @@ def euler_rotation(delta_x: Float[Array, " 3"]): Copied and modified from bilby-cython/geometry.pyx """ - norm = jnp.power( - delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5 - ) + norm = jnp.linalg.vector_norm(delta_x) + cos_beta = delta_x[2] / norm - sin_beta = jnp.power(1 - cos_beta**2, 0.5) + sin_beta = jnp.sqrt(1 - cos_beta**2) alpha = jnp.atan2(-delta_x[1] * cos_beta, delta_x[0]) gamma = jnp.atan2(delta_x[1], delta_x[0]) @@ -343,8 +477,8 @@ def euler_rotation(delta_x: Float[Array, " 3"]): return rotation -def zenith_azimuth_to_theta_phi( - zenith: Float, azimuth: Float, delta_x: Float[Array, " 3"] +def angle_rotation( + zenith: Float, azimuth: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame. @@ -357,8 +491,8 @@ def zenith_azimuth_to_theta_phi( Zenith angle. azimuth : Float Azimuthal angle. - delta_x : Float - The vector pointing from the first detector to the second detector. + rotation : Float[Array, " 3 3"] + The rotation matrix. Returns ------- @@ -372,8 +506,6 @@ def zenith_azimuth_to_theta_phi( sin_zenith = jnp.sin(zenith) cos_zenith = jnp.cos(zenith) - rotation = euler_rotation(delta_x) - theta = jnp.acos( rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth @@ -389,7 +521,7 @@ def zenith_azimuth_to_theta_phi( + rotation[0][2] * cos_zenith, ) + 2 * jnp.pi, - (2 * jnp.pi), + 2 * jnp.pi, ) return theta, phi @@ -416,11 +548,170 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F """ ra = phi + gmst dec = jnp.pi / 2 - theta + ra = ra % (2 * jnp.pi) return ra, dec +def spin_to_cartesian_spin( + thetaJN: Float, + phiJL: Float, + theta1: Float, + theta2: Float, + phi12: Float, + chi1: Float, + chi2: Float, + M_c: Float, + q: Float, + fRef: Float, + phiRef: Float, +) -> tuple[Float, Float, Float, Float, Float, Float, Float]: + """ + Transforming the spin parameters + + The code is based on the approach used in LALsimulation: + https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html + + Parameters: + ------- + thetaJN: Float + Zenith angle between the total angular momentum and the line of sight + phiJL: Float + Difference between total and orbital angular momentum azimuthal angles + theta1: Float + Zenith angle between the spin and orbital angular momenta for the primary object + theta2: Float + Zenith angle between the spin and orbital angular momenta for the secondary object + phi12: Float + Difference between the azimuthal angles of the individual spin vector projections + onto the orbital plane + chi1: Float + Primary object aligned spin: + chi2: Float + Secondary object aligned spin: + M_c: Float + The chirp mass + eta: Float + The symmetric mass ratio + fRef: Float + The reference frequency + phiRef: Float + Binary phase at a reference frequency + + Returns: + ------- + iota: Float + Zenith angle between the orbital angular momentum and the line of sight + S1x: Float + The x-component of the primary spin + S1y: Float + The y-component of the primary spin + S1z: Float + The z-component of the primary spin + S2x: Float + The x-component of the secondary spin + S2y: Float + The y-component of the secondary spin + S2z: Float + The z-component of the secondary spin + """ + + def rotate_y(angle, vec): + """ + Rotate the vector (x, y, z) about y-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + def rotate_z(angle, vec): + """ + Rotate the vector (x, y, z) about z-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + LNh = jnp.array([0.0, 0.0, 1.0]) + + s1hat = jnp.array( + [ + jnp.sin(theta1) * jnp.cos(phiRef), + jnp.sin(theta1) * jnp.sin(phiRef), + jnp.cos(theta1), + ] + ) + s2hat = jnp.array( + [ + jnp.sin(theta2) * jnp.cos(phi12 + phiRef), + jnp.sin(theta2) * jnp.sin(phi12 + phiRef), + jnp.cos(theta2), + ] + ) + + m1, m2 = Mc_q_to_m1_m2(M_c, q) + eta = q / (1 + q) ** 2 + v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef) + + Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0)) + s1 = m1 * m1 * chi1 * s1hat + s2 = m2 * m2 * chi2 * s2hat + J = s1 + s2 + jnp.array([0.0, 0.0, Lmag]) + + Jhat = J / jnp.linalg.norm(J) + theta0 = jnp.arccos(Jhat[2]) + phi0 = jnp.arctan2(Jhat[1], Jhat[0]) + + # Rotation 1: + s1hat = rotate_z(-phi0, s1hat) + s2hat = rotate_z(-phi0, s2hat) + + # Rotation 2: + LNh = rotate_y(-theta0, LNh) + s1hat = rotate_y(-theta0, s1hat) + s2hat = rotate_y(-theta0, s2hat) + + # Rotation 3: + LNh = rotate_z(phiJL - jnp.pi, LNh) + s1hat = rotate_z(phiJL - jnp.pi, s1hat) + s2hat = rotate_z(phiJL - jnp.pi, s2hat) + + # Compute iota + N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) + iota = jnp.arccos(jnp.dot(N, LNh)) + + thetaLJ = jnp.arccos(LNh[2]) + phiL = jnp.arctan2(LNh[1], LNh[0]) + + # Rotation 4: + s1hat = rotate_z(-phiL, s1hat) + s2hat = rotate_z(-phiL, s2hat) + N = rotate_z(-phiL, N) + + # Rotation 5: + s1hat = rotate_y(-thetaLJ, s1hat) + s2hat = rotate_y(-thetaLJ, s2hat) + N = rotate_y(-thetaLJ, N) + + # Rotation 6: + phiN = jnp.arctan2(N[1], N[0]) + s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat) + s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat) + + S1 = s1hat * chi1 + S2 = s2hat * chi2 + return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] + + def zenith_azimuth_to_ra_dec( - zenith: Float, azimuth: Float, gmst: Float, delta_x: Float[Array, " 3"] + zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. @@ -433,8 +724,8 @@ def zenith_azimuth_to_ra_dec( Azimuthal angle. gmst : Float Greenwich mean sidereal time. - delta_x : Float - The vector pointing from the first detector to the second detector. + rotation : Float[Array, " 3 3"] + The rotation matrix. Copied and modified from bilby/gw/utils.py @@ -445,26 +736,62 @@ def zenith_azimuth_to_ra_dec( dec : Float Declination. """ - theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + theta, phi = angle_rotation(zenith, azimuth, rotation) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) - ra = ra % (2 * jnp.pi) return ra, dec -def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]: +def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: """ - A numerically stable method to evaluate log of - a modified Bessel function of order 0. - It is used in the phase-marginalized likelihood. + Transforming the right ascension ra and declination dec to the polar angle + theta and azimuthal angle phi. Parameters - ========== - x: array-like - Value(s) at which to evaluate the function + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. Returns - ======= - array-like: - The natural logarithm of the bessel function + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ + phi = ra - gmst + theta = jnp.pi / 2 - dec + phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) + return theta, phi + + +def ra_dec_to_zenith_azimuth( + ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] +) -> tuple[Float, Float]: + """ + Transforming the right ascension and declination to the zenith angle and azimuthal angle. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + rotation : Float[Array, " 3 3"] + The rotation matrix. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. """ - return jnp.log(i0e(x)) + x + theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) + zenith, azimuth = angle_rotation(theta, phi, rotation) + return zenith, azimuth diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py new file mode 100644 index 00000000..3ad51e62 --- /dev/null +++ b/src/jimgw/transforms.py @@ -0,0 +1,463 @@ +from abc import ABC +from typing import Callable + +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped + + +class Transform(ABC): + """ + Base class for transform. + The purpose of this class is purely for keeping name + """ + + name_mapping: tuple[list[str], list[str]] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + self.name_mapping = name_mapping + + def propagate_name(self, x: list[str]) -> list[str]: + input_set = set(x) + from_set = set(self.name_mapping[0]) + to_set = set(self.name_mapping[1]) + return list(input_set - from_set | to_set) + + +class NtoMTransform(Transform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + def forward(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Push forward the input x to transformed coordinate y. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + """ + x_copy = x.copy() + output_params = self.transform_func(x_copy) + jax.tree.map( + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy + + +class NtoNTransform(NtoMTransform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + @property + def n_dim(self) -> int: + return len(self.name_mapping[0]) + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + """ + Transform the input x to transformed coordinate y and return the log Jacobian determinant. + This only works if the transform is a N -> N transform. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + log_det : Float + The log Jacobian determinant. + """ + x_copy = x.copy() + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jax.tree.map( + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy, jacobian + + +class BijectiveTransform(NtoNTransform): + + inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + """ + Inverse transform the input y to original coordinate x. + + Parameters + ---------- + y : dict[str, Float] + The transformed dictionary. + + Returns + ------- + x : dict[str, Float] + The original dictionary. + log_det : Float + The log Jacobian determinant. + """ + y_copy = y.copy() + transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) + output_params = self.inverse_transform_func(transform_params) + jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jax.tree.map( + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy, jacobian + + def backward(self, y: dict[str, Float]) -> dict[str, Float]: + """ + Pull back the input y to original coordinate x and return the log Jacobian determinant. + + Parameters + ---------- + y : dict[str, Float] + The transformed dictionary. + + Returns + ------- + x : dict[str, Float] + The original dictionary. + """ + y_copy = y.copy() + output_params = self.inverse_transform_func(y_copy) + jax.tree.map( + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy + + +@jaxtyped(typechecker=typechecker) +class ScaleTransform(BijectiveTransform): + scale: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + scale: Float, + ): + super().__init__(name_mapping) + self.scale = scale + self.transform_func = lambda x: { + name_mapping[1][i]: x[name_mapping[0][i]] * self.scale + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: x[name_mapping[1][i]] / self.scale + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class OffsetTransform(BijectiveTransform): + offset: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + offset: Float, + ): + super().__init__(name_mapping) + self.offset = offset + self.transform_func = lambda x: { + name_mapping[1][i]: x[name_mapping[0][i]] + self.offset + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: x[name_mapping[1][i]] - self.offset + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class LogitTransform(BijectiveTransform): + """ + Logit transform following + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: 1 / (1 + jnp.exp(-x[name_mapping[0][i]])) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.log( + x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]]) + ) + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class ArcSineTransform(BijectiveTransform): + """ + ArcSine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.arcsin(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.sin(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class BoundToBound(BijectiveTransform): + """ + Bound to bound transformation + """ + + original_lower_bound: Float[Array, " n_dim"] + original_upper_bound: Float[Array, " n_dim"] + target_lower_bound: Float[Array, " n_dim"] + target_upper_bound: Float[Array, " n_dim"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float[Array, " n_dim"], + original_upper_bound: Float[Array, " n_dim"], + target_lower_bound: Float[Array, " n_dim"], + target_upper_bound: Float[Array, " n_dim"], + ): + super().__init__(name_mapping) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + self.target_lower_bound = target_lower_bound + self.target_upper_bound = target_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: (x[name_mapping[0][i]] - self.original_lower_bound[i]) + * (self.target_upper_bound[i] - self.target_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) + + self.target_lower_bound[i] + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: (x[name_mapping[1][i]] - self.target_lower_bound[i]) + * (self.original_upper_bound[i] - self.original_lower_bound[i]) + / (self.target_upper_bound[i] - self.target_lower_bound[i]) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class BoundToUnbound(BijectiveTransform): + """ + Bound to unbound transformation + """ + + original_lower_bound: Float + original_upper_bound: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, + original_upper_bound: Float, + ): + + def logit(x): + return jnp.log(x / (1 - x)) + + super().__init__(name_mapping) + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.original_upper_bound = jnp.atleast_1d(original_upper_bound) + + self.transform_func = lambda x: { + name_mapping[1][i]: logit( + (x[name_mapping[0][i]] - self.original_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) + ) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: ( + self.original_upper_bound[i] - self.original_lower_bound[i] + ) + / (1 + jnp.exp(-x[name_mapping[1][i]])) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class SingleSidedUnboundTransform(BijectiveTransform): + """ + Unbound upper limit transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.log(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + +class PowerLawTransform(BijectiveTransform): + """ + PowerLaw transformation + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + xmin: Float + xmax: Float + alpha: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + alpha: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.alpha = alpha + self.transform_func = lambda x: { + name_mapping[1][i]: ( + self.xmin ** (1.0 + self.alpha) + + x[name_mapping[0][i]] + * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + ** (1.0 / (1.0 + self.alpha)) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: ( + ( + x[name_mapping[1][i]] ** (1.0 + self.alpha) + - self.xmin ** (1.0 + self.alpha) + ) + / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + for i in range(len(name_mapping[1])) + } + + +class ParetoTransform(BijectiveTransform): + """ + Pareto transformation: Power law when alpha = -1 + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.transform_func = lambda x: { + name_mapping[1][i]: self.xmin + * jnp.exp(x[name_mapping[0][i]] * jnp.log(self.xmax / self.xmin)) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: ( + jnp.log(x[name_mapping[1][i]] / self.xmin) + / jnp.log(self.xmax / self.xmin) + ) + for i in range(len(name_mapping[1])) + } + + +def reverse_bijective_transform( + original_transform: BijectiveTransform, +) -> BijectiveTransform: + + reversed_name_mapping = ( + original_transform.name_mapping[1], + original_transform.name_mapping[0], + ) + reversed_transform = BijectiveTransform(name_mapping=reversed_name_mapping) + reversed_transform.transform_func = original_transform.inverse_transform_func + reversed_transform.inverse_transform_func = original_transform.transform_func + reversed_transform.__repr__ = lambda: f"Reversed{repr(original_transform)}" + + return reversed_transform diff --git a/src/jimgw/utils.py b/src/jimgw/utils.py new file mode 100644 index 00000000..70c6e166 --- /dev/null +++ b/src/jimgw/utils.py @@ -0,0 +1,37 @@ +import jax.numpy as jnp +from jax.scipy.special import i0e +from jaxtyping import Array, Float + +from jimgw.prior import Prior + + +def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output + + +def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]: + """ + A numerically stable method to evaluate log of + a modified Bessel function of order 0. + It is used in the phase-marginalized likelihood. + + Parameters + ========== + x: array-like + Value(s) at which to evaluate the function + + Returns + ======= + array-like: + The natural logarithm of the bessel function + """ + return jnp.log(i0e(x)) + x diff --git a/test/integration/.gitignore b/test/integration/.gitignore new file mode 100644 index 00000000..a7f7ef0e --- /dev/null +++ b/test/integration/.gitignore @@ -0,0 +1,2 @@ +outdir/ +figures/ diff --git a/test/integration/test_GW150914_D.py b/test/integration/test_GW150914_D.py new file mode 100644 index 00000000..5103e5d8 --- /dev/null +++ b/test/integration/test_GW150914_D.py @@ -0,0 +1,134 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam +from flowMC.utils.postprocessing import plot_summary +import optax + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + ComponentMassesToChirpMassMassRatioTransform, + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform, +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() \ No newline at end of file diff --git a/test/integration/test_GW150914_D_heterodyne.py b/test/integration/test_GW150914_D_heterodyne.py new file mode 100644 index 00000000..cbac1788 --- /dev/null +++ b/test/integration/test_GW150914_D_heterodyne.py @@ -0,0 +1,135 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + ComponentMassesToChirpMassMassRatioTransform, + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform, +] + +likelihood = HeterodynedTransientLikelihoodFD( + ifos, + prior=prior, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_steps=5, + popsize=10, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() \ No newline at end of file diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py new file mode 100644 index 00000000..9892058d --- /dev/null +++ b/test/integration/test_GW150914_Pv2.py @@ -0,0 +1,142 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior(0.125, 1., parameter_names=["q"]) +theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) +phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) +theta_1_prior = SinePrior(parameter_names=["theta_1"]) +theta_2_prior = SinePrior(parameter_names=["theta_2"]) +phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) +a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) +a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) +dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + Mc_prior, + q_prior, + theta_jn_prior, + phi_jl_prior, + theta_1_prior, + theta_2_prior, + phi_12_prior, + a_1_prior, + a_2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) +] + +likelihood_transforms = [ + SpinToCartesianSpinTransform(freq_ref=20.0), + MassRatioToSymmetricMassRatioTransform, +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(15) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() \ No newline at end of file diff --git a/test/integration/test_mass_transforms.py b/test/integration/test_mass_transforms.py new file mode 100644 index 00000000..65b95244 --- /dev/null +++ b/test/integration/test_mass_transforms.py @@ -0,0 +1,106 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "3" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" + +import numpy as np +import matplotlib.pyplot as plt +import corner +import jax +import jax.numpy as jnp +from jaxtyping import Float + +from jimgw.prior import UniformPrior, CombinePrior +from jimgw.single_event.transforms import ChirpMassMassRatioToComponentMassesTransform +from jimgw.base import LikelihoodBase +from jimgw.jim import Jim + +params = {"axes.grid": True, + "text.usetex" : True, + "font.family" : "serif", + "ytick.color" : "black", + "xtick.color" : "black", + "axes.labelcolor" : "black", + "axes.edgecolor" : "black", + "font.serif" : ["Computer Modern Serif"], + "xtick.labelsize": 16, + "ytick.labelsize": 16, + "axes.labelsize": 16, + "legend.fontsize": 16, + "legend.title_fontsize": 16, + "figure.titlesize": 16} + +plt.rcParams.update(params) + +# Improved corner kwargs +default_corner_kwargs = dict(bins=40, + smooth=1., + show_titles=False, + label_kwargs=dict(fontsize=16), + title_kwargs=dict(fontsize=16), + color="blue", + # quantiles=[], + # levels=[0.9], + plot_density=True, + plot_datapoints=False, + fill_contours=True, + max_n_ticks=4, + min_n_ticks=3, + truth_color = "red", + save=False) + +# Likelihood for this test: + +class MyLikelihood(LikelihoodBase): + """Simple toy likelihood: Gaussian centered on the true component masses""" + + true_m1: Float + true_m2: Float + + def __init__(self, + true_m1: Float, + true_m2: Float): + + self.true_m1 = true_m1 + self.true_m2 = true_m2 + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + m1, m2 = params['m_1'], params['m_2'] + m1_std = 0.1 + m2_std = 0.1 + return -0.5 * (((m1 - self.true_m1) / m1_std)**2 + ((m2 - self.true_m2) / m2_std)**2) + +# Setup +true_m1 = 1.6 +true_m2 = 1.4 +true_mc = (true_m1 * true_m2)**(3/5) / (true_m1 + true_m2)**(1/5) +true_q = true_m2 / true_m1 + +# Priors +eps = 0.5 # half of width of the chirp mass prior +mc_prior = UniformPrior(true_mc - eps, true_mc + eps, parameter_names=['M_c']) +q_prior = UniformPrior(0.125, 1.0, parameter_names=['q']) +combine_prior = CombinePrior([mc_prior, q_prior]) + +# Likelihood and transform +likelihood = MyLikelihood(true_m1, true_m2) +mass_transform = ChirpMassMassRatioToComponentMassesTransform + +print("Checking mass_transform repr") +print(repr(mass_transform)) + +# Other stuff we have to give to Jim to make it work +step = 5e-3 +local_sampler_arg = {"step_size": step * jnp.eye(2)} + +# Jim: +jim = Jim(likelihood, + combine_prior, + likelihood_transforms=[mass_transform], + n_chains = 10, + parameter_names=['M_c', 'q'], + n_loop_training=2, + n_loop_production=2, + local_sampler_arg=local_sampler_arg) + +jim.sample(jax.random.PRNGKey(0)) +jim.print_summary() \ No newline at end of file diff --git a/test/test_prior.py b/test/test_prior.py deleted file mode 100644 index 8b137891..00000000 --- a/test/test_prior.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/test/test_transform.py b/test/test_transform.py deleted file mode 100644 index 5d9fe464..00000000 --- a/test/test_transform.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np -import jax.numpy as jnp - -class TestTransform: - def test_sky_location_transform(self): - from bilby.gw.utils import zenith_azimuth_to_ra_dec as bilby_earth_to_sky - from bilby.gw.detector.networks import InterferometerList - - from jimgw.single_event.utils import zenith_azimuth_to_ra_dec as jimgw_earth_to_sky - from jimgw.single_event.detector import detector_preset - from astropy.time import Time - - ifos = ["H1", "L1"] - geocent_time = 1000000000 - - import matplotlib.pyplot as plt - - for zenith in np.linspace(0, np.pi, 10): - for azimuth in np.linspace(0, 2*np.pi, 10): - bilby_sky_location = np.array(bilby_earth_to_sky(zenith, azimuth, geocent_time, InterferometerList(ifos))) - jimgw_sky_location = np.array(jimgw_earth_to_sky(zenith, azimuth, Time(geocent_time, format="gps").sidereal_time("apparent", "greenwich").rad, detector_preset[ifos[0]].vertex - detector_preset[ifos[1]].vertex)) - assert np.allclose(bilby_sky_location, jimgw_sky_location, atol=1e-4) - - def test_spin_transform(self): - from bilby.gw.conversion import bilby_to_lalsimulation_spins as bilby_spin_transform - from bilby.gw.conversion import symmetric_mass_ratio_to_mass_ratio, chirp_mass_and_mass_ratio_to_component_masses - - from jimgw.single_event.utils import spin_to_cartesian_spin as jimgw_spin_transform - - for _ in range(100): - thetaJN = jnp.array(np.random.uniform(0, np.pi)) - phiJL = jnp.array(np.random.uniform(0, np.pi)) - theta1 = jnp.array(np.random.uniform(0, np.pi)) - theta2 = jnp.array(np.random.uniform(0, np.pi)) - phi12 = jnp.array(np.random.uniform(0, np.pi)) - chi1 = jnp.array(np.random.uniform(0, 1)) - chi2 = jnp.array(np.random.uniform(0, 1)) - M_c = jnp.array(np.random.uniform(1, 100)) - eta = jnp.array(np.random.uniform(0.1, 0.25)) - fRef = jnp.array(np.random.uniform(10, 1000)) - phiRef = jnp.array(np.random.uniform(0, 2*np.pi)) - - q = symmetric_mass_ratio_to_mass_ratio(eta) - m1, m2 = chirp_mass_and_mass_ratio_to_component_masses(M_c, q) - MsunInkg = 1.9884e30 - bilby_spin = jnp.array(bilby_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, m1*MsunInkg, m2*MsunInkg, fRef, phiRef)) - jimgw_spin = jnp.array(jimgw_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, M_c, eta, fRef, phiRef)) - assert np.allclose(bilby_spin, jimgw_spin, atol=1e-4) diff --git a/test/test_detector.py b/test/unit/test_detector.py similarity index 100% rename from test/test_detector.py rename to test/unit/test_detector.py diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py new file mode 100644 index 00000000..852ded16 --- /dev/null +++ b/test/unit/test_prior.py @@ -0,0 +1,113 @@ +from jimgw.prior import * +from jimgw.utils import trace_prior_parent +import scipy.stats as stats + + +class TestUnivariatePrior: + def test_logistic(self): + p = LogisticDistribution(["x"]) + # Check that the log_prob is finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Cross-check log_prob with scipy.stats.logistic + x = jnp.linspace(-10.0, 10.0, 1000) + assert jnp.allclose(jax.vmap(p.log_prob)(p.add_name(x[None])), stats.logistic.logpdf(x)) + + def test_standard_normal(self): + p = StandardNormalDistribution(["x"]) + # Check that the log_prob is finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Cross-check log_prob with scipy.stats.norm + x = jnp.linspace(-10.0, 10.0, 1000) + assert jnp.allclose(jax.vmap(p.log_prob)(p.add_name(x[None])), stats.norm.logpdf(x)) + + def test_uniform(self): + xmin, xmax = -10.0, 10.0 + p = UniformPrior(xmin, xmax, ["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) + # Check that the log_prob is correct in the support + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin))) + + def test_sine(self): + p = SinePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) + # Check that the log_prob is finite + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Check that the log_prob is correct in the support + x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) + y = jax.vmap(p.base_prior.base_prior.transform)(x) + y = jax.vmap(p.base_prior.transform)(y) + y = jax.vmap(p.transform)(y) + assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.sin(y['x'])/2.0)) + + def test_cosine(self): + p = CosinePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) + # Check that the log_prob is finite + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Check that the log_prob is correct in the support + x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) + y = jax.vmap(p.base_prior.transform)(x) + y = jax.vmap(p.transform)(y) + assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.cos(y['x'])/2.0)) + + def test_uniform_sphere(self): + p = UniformSpherePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x_mag'])) + assert jnp.all(jnp.isfinite(samples['x_theta'])) + assert jnp.all(jnp.isfinite(samples['x_phi'])) + # Check that the log_prob is finite + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + + + def test_power_law(self): + def powerlaw_log_pdf(x, alpha, xmin, xmax): + if alpha == -1.0: + normalization = 1./(jnp.log(xmax) - jnp.log(xmin)) + else: + normalization = (1.0 + alpha) / (xmax**(1.0 + alpha) - xmin**(1.0 + alpha)) + return jnp.log(normalization) + alpha * jnp.log(x) + + def func(alpha): + xmin = 0.05 + xmax = 10.0 + p = PowerLawPrior(xmin, xmax, alpha, ["x"]) + # Check that all the samples are finite + powerlaw_samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) + + # Check that all the log_probs are finite + log_p = jax.vmap(p.log_prob, [0])(powerlaw_samples) + assert jnp.all(jnp.isfinite(log_p)) + + # Check that the log_prob is correct in the support + log_prob = jax.vmap(p.log_prob)(powerlaw_samples) + standard_log_prob = powerlaw_log_pdf(powerlaw_samples['x'], alpha, xmin, xmax) + # log pdf of powerlaw + assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4) + + # Test Pareto Transform + func(-1.0) + # Test other values of alpha + print("Testing PowerLawPrior") + positive_alpha = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0] + for alpha_val in positive_alpha: + func(alpha_val) + negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] + for alpha_val in negative_alpha: + func(alpha_val) \ No newline at end of file diff --git a/test/unit/test_transform.py b/test/unit/test_transform.py new file mode 100644 index 00000000..14162f79 --- /dev/null +++ b/test/unit/test_transform.py @@ -0,0 +1,48 @@ +# import numpy as np +# import jax.numpy as jnp + +# class TestTransform: +# def test_sky_location_transform(self): +# from bilby.gw.utils import zenith_azimuth_to_ra_dec as bilby_earth_to_sky +# from bilby.gw.detector.networks import InterferometerList + +# from jimgw.single_event.utils import zenith_azimuth_to_ra_dec as jimgw_earth_to_sky +# from jimgw.single_event.detector import detector_preset +# from astropy.time import Time + +# ifos = ["H1", "L1"] +# geocent_time = 1000000000 + +# import matplotlib.pyplot as plt + +# for zenith in np.linspace(0, np.pi, 10): +# for azimuth in np.linspace(0, 2*np.pi, 10): +# bilby_sky_location = np.array(bilby_earth_to_sky(zenith, azimuth, geocent_time, InterferometerList(ifos))) +# jimgw_sky_location = np.array(jimgw_earth_to_sky(zenith, azimuth, Time(geocent_time, format="gps").sidereal_time("apparent", "greenwich").rad, detector_preset[ifos[0]].vertex - detector_preset[ifos[1]].vertex)) +# assert np.allclose(bilby_sky_location, jimgw_sky_location, atol=1e-4) + +# def test_spin_transform(self): +# from bilby.gw.conversion import bilby_to_lalsimulation_spins as bilby_spin_transform +# from bilby.gw.conversion import symmetric_mass_ratio_to_mass_ratio, chirp_mass_and_mass_ratio_to_component_masses + +# from jimgw.single_event.utils import spin_to_cartesian_spin as jimgw_spin_transform + +# for _ in range(100): +# thetaJN = jnp.array(np.random.uniform(0, np.pi)) +# phiJL = jnp.array(np.random.uniform(0, np.pi)) +# theta1 = jnp.array(np.random.uniform(0, np.pi)) +# theta2 = jnp.array(np.random.uniform(0, np.pi)) +# phi12 = jnp.array(np.random.uniform(0, np.pi)) +# chi1 = jnp.array(np.random.uniform(0, 1)) +# chi2 = jnp.array(np.random.uniform(0, 1)) +# M_c = jnp.array(np.random.uniform(1, 100)) +# eta = jnp.array(np.random.uniform(0.1, 0.25)) +# fRef = jnp.array(np.random.uniform(10, 1000)) +# phiRef = jnp.array(np.random.uniform(0, 2*np.pi)) + +# q = symmetric_mass_ratio_to_mass_ratio(eta) +# m1, m2 = chirp_mass_and_mass_ratio_to_component_masses(M_c, q) +# MsunInkg = 1.9884e30 +# bilby_spin = jnp.array(bilby_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, m1*MsunInkg, m2*MsunInkg, fRef, phiRef)) +# jimgw_spin = jnp.array(jimgw_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, M_c, eta, fRef, phiRef)) +# assert np.allclose(bilby_spin, jimgw_spin, atol=1e-4)