diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 4806efed..d70c2ae1 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -3,11 +3,7 @@ name: Python package -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] +on: [push, pull_request] jobs: build: diff --git a/docs/tutorials/naming_system.md b/docs/tutorials/naming_system.md new file mode 100644 index 00000000..e69de29b diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index c0878afc..88ffe52b 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -31,7 +31,6 @@ run = SingleEventRun( seed=0, - path="test_data/GW150914/", detectors=["H1", "L1"], priors={ "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0}, diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 0ab1105a..74f65efc 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -8,17 +8,51 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior +from jimgw.transforms import BijectiveTransform, NtoMTransform class Jim(object): """ Master class for interfacing with flowMC - """ - def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): - self.Likelihood = likelihood - self.Prior = prior + likelihood: LikelihoodBase + prior: Prior + + # Name of parameters to sample from + sample_transforms: list[BijectiveTransform] + likelihood_transforms: list[NtoMTransform] + parameter_names: list[str] + sampler: Sampler + + def __init__( + self, + likelihood: LikelihoodBase, + prior: Prior, + sample_transforms: list[BijectiveTransform] = [], + likelihood_transforms: list[NtoMTransform] = [], + **kwargs, + ): + self.likelihood = likelihood + self.prior = prior + + self.sample_transforms = sample_transforms + self.likelihood_transforms = likelihood_transforms + self.parameter_names = prior.parameter_names + + if len(sample_transforms) == 0: + print( + "No sample transforms provided. Using prior parameters as sampling parameters" + ) + else: + print("Using sample transforms") + for transform in sample_transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) + + if len(likelihood_transforms) == 0: + print( + "No likelihood transforms provided. Using prior parameters as likelihood parameters" + ) seed = kwargs.get("seed", 0) @@ -33,11 +67,11 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): rng_key, subkey = jax.random.split(rng_key) model = MaskedCouplingRQSpline( - self.Prior.n_dim, num_layers, hidden_size, num_bins, subkey + self.prior.n_dim, num_layers, hidden_size, num_bins, subkey ) - self.Sampler = Sampler( - self.Prior.n_dim, + self.sampler = Sampler( + self.prior.n_dim, rng_key, None, # type: ignore local_sampler, @@ -45,18 +79,38 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): **kwargs, ) + def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: + """ + Turn an array into a dictionary + + Parameters + ---------- + x : Array + An array of parameters. Shape (n_dim,). + """ + + return dict(zip(self.parameter_names, x)) + def posterior(self, params: Float[Array, " n_dim"], data: dict): - prior_params = self.Prior.add_name(params.T) - prior = self.Prior.log_prob(prior_params) + named_params = self.add_name(params) + transform_jacobian = 0.0 + for transform in self.sample_transforms: + named_params, jacobian = transform.inverse(named_params) + transform_jacobian += jacobian + prior = self.prior.log_prob(named_params) + transform_jacobian + for transform in self.likelihood_transforms: + named_params = transform.forward(named_params) return ( - self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior + self.likelihood.evaluate(named_params, data) + prior ) def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: - initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains) + initial_guess_named = self.prior.sample(key, self.sampler.n_chains) + for transform in self.sample_transforms: + initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T - self.Sampler.sample(initial_guess, None) # type: ignore + self.sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( self, @@ -67,7 +121,7 @@ def maximize_likelihood( ): key = jax.random.PRNGKey(seed) set_nwalkers = set_nwalkers - initial_guess = self.Prior.sample(key, set_nwalkers) + initial_guess = self.prior.sample(key, set_nwalkers) def negative_posterior(x: Float[Array, " n_dim"]): return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet @@ -78,33 +132,59 @@ def negative_posterior(x: Float[Array, " n_dim"]): print("Done compiling") print("Starting the optimizer") - optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True) + optimizer = EvolutionaryOptimizer(self.prior.n_dim, verbose=True) _ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return best_fit - def print_summary(self, transform: bool = True): + def print_summary(self): """ Generate summary of the run """ - train_summary = self.Sampler.get_sampler_state(training=True) - production_summary = self.Sampler.get_sampler_state(training=False) + train_summary = self.sampler.get_sampler_state(training=True) + production_summary = self.sampler.get_sampler_state(training=False) - training_chain = train_summary["chains"].reshape(-1, self.Prior.n_dim).T - training_chain = self.Prior.add_name(training_chain) - if transform: - training_chain = self.Prior.transform(training_chain) + training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names)).T + if self.sample_transforms: + transformed_chain = {} + named_sample = self.add_name(training_chain[0]) + for transform in self.sample_transforms: + named_sample = transform.backward(named_sample) + for key, value in named_sample.items(): + transformed_chain[key] = [value] + for sample in training_chain[1:]: + named_sample = self.add_name(sample) + for transform in self.sample_transforms: + named_sample = transform.backward(named_sample) + for key, value in named_sample.items(): + transformed_chain[key].append(value) + training_chain = transformed_chain + else: + training_chain = self.add_name(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] training_loss = train_summary["loss_vals"] - production_chain = production_summary["chains"].reshape(-1, self.Prior.n_dim).T - production_chain = self.Prior.add_name(production_chain) - if transform: - production_chain = self.Prior.transform(production_chain) + production_chain = production_summary["chains"].reshape(-1, len(self.parameter_names)).T + if self.sample_transforms: + transformed_chain = {} + named_sample = self.add_name(production_chain[0]) + for transform in self.sample_transforms: + named_sample = transform.backward(named_sample) + for key, value in named_sample.items(): + transformed_chain[key] = [value] + for sample in production_chain[1:]: + named_sample = self.add_name(sample) + for transform in self.sample_transforms: + named_sample = transform.backward(named_sample) + for key, value in named_sample.items(): + transformed_chain[key].append(value) + production_chain = transformed_chain + else: + production_chain = self.add_name(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -112,7 +192,7 @@ def print_summary(self, transform: bool = True): print("Training summary") print("=" * 10) for key, value in training_chain.items(): - print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}") + print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}") print( f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}" ) @@ -129,7 +209,7 @@ def print_summary(self, transform: bool = True): print("Production summary") print("=" * 10) for key, value in production_chain.items(): - print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}") + print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}") print( f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}" ) @@ -156,12 +236,32 @@ def get_samples(self, training: bool = False) -> dict: """ if training: - chains = self.Sampler.get_sampler_state(training=True)["chains"] + chains = self.sampler.get_sampler_state(training=True)["chains"] + else: + chains = self.sampler.get_sampler_state(training=False)["chains"] + + # Need rewrite to output chains instead of flattened samples + chains = chains.reshape(-1, len(self.parameter_names)).T + if self.sample_transforms: + transformed_chain = {} + named_sample = self.add_name(chains[0]) + for transform in self.sample_transforms: + named_sample = transform.backward(named_sample) + for key, value in named_sample.items(): + transformed_chain[key] = [value] + for sample in chains[1:]: + named_sample = self.add_name(sample) + for transform in self.sample_transforms: + named_sample = transform.backward(named_sample) + for key, value in named_sample.items(): + transformed_chain[key].append(value) + output = transformed_chain else: - chains = self.Sampler.get_sampler_state(training=False)["chains"] + output = self.add_name(chains) - chains = self.Prior.transform(self.Prior.add_name(chains.transpose(2, 0, 1))) - return chains + for key in output.keys(): + output[key] = jnp.array(output[key]) + return output def plot(self): pass diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index f1827753..031a4133 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -1,14 +1,18 @@ from dataclasses import field -from typing import Callable, Union import jax import jax.numpy as jnp -from flowMC.nfmodel.base import Distribution -from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped from beartype import beartype as typechecker -from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.single_event.detector import GroundBased2G, detector_preset -from astropy.time import Time +from flowMC.nfmodel.base import Distribution +from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped + +from jimgw.transforms import ( + BijectiveTransform, + LogitTransform, + ScaleTransform, + OffsetTransform, + ArcSineTransform, +) class Prior(Distribution): @@ -21,58 +25,21 @@ class Prior(Distribution): the names of the parameters and the transforms that are applied to them. """ - naming: list[str] - transforms: dict[str, tuple[str, Callable]] = field(default_factory=dict) + parameter_names: list[str] + composite: bool = False @property - def n_dim(self): - return len(self.naming) + def n_dim(self) -> int: + return len(self.parameter_names) - def __init__( - self, naming: list[str], transforms: dict[str, tuple[str, Callable]] = {} - ): + def __init__(self, parameter_names: list[str]): """ Parameters ---------- - naming : list[str] + parameter_names : list[str] A list of names for the parameters of the prior. - transforms : dict[tuple[str,Callable]] - A dictionary of transforms to apply to the parameters. The keys are - the names of the parameters and the values are a tuple of the name - of the transform and the transform itself. """ - self.naming = naming - self.transforms = {} - - def make_lambda(name): - return lambda x: x[name] - - for name in naming: - if name in transforms: - self.transforms[name] = transforms[name] - else: - # Without the function, the lambda will refer to the variable name instead of its value, - # which will make lambda reference the last value of the variable name - self.transforms[name] = (name, make_lambda(name)) - - def transform(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Apply the transforms to the parameters. - - Parameters - ---------- - x : dict - A dictionary of parameters. Names should match the ones in the prior. - - Returns - ------- - x : dict - A dictionary of parameters with the transforms applied. - """ - output = {} - for value in self.transforms.values(): - output[value[0]] = value[1](x) - return output + self.parameter_names = parameter_names def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ @@ -84,43 +51,33 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: An array of parameters. Shape (n_dim,). """ - return dict(zip(self.naming, x)) + return dict(zip(self.parameter_names, x)) def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: raise NotImplementedError - def log_prob(self, x: dict[str, Array]) -> Float: + def log_prob(self, z: dict[str, Array]) -> Float: raise NotImplementedError @jaxtyped(typechecker=typechecker) -class Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 +class LogisticDistribution(Prior): def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax})" + return f"LogisticDistribution(parameter_names={self.parameter_names})" - def __init__( - self, - xmin: Float, - xmax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Uniform needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin + def __init__(self, parameter_names: list[str], **kwargs): + super().__init__(parameter_names) + self.composite = False + assert self.n_dim == 1, "LogisticDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: """ - Sample from a uniform distribution. + Sample from a logistic distribution. Parameters ---------- @@ -135,72 +92,33 @@ def sample( Samples from the distribution. The keys are the names of the parameters. """ - samples = jax.random.uniform( - rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax - ) + samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) + samples = jnp.log(samples / (1 - samples)) return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - return output + jnp.log(1.0 / (self.xmax - self.xmin)) + def log_prob(self, z: dict[str, Float]) -> Float: + variable = z[self.parameter_names[0]] + return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) @jaxtyped(typechecker=typechecker) -class Unconstrained_Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 +class StandardNormalDistribution(Prior): def __repr__(self): - return f"Unconstrained_Uniform(xmin={self.xmin}, xmax={self.xmax})" - - def __init__( - self, - xmin: Float, - xmax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - local_transform = self.transforms + return f"StandardNormalDistribution(parameter_names={self.parameter_names})" - def new_transform(param): - param[self.naming[0]] = self.to_range(param[self.naming[0]]) - return local_transform[self.naming[0]][1](param) - - self.transforms = { - self.naming[0]: (local_transform[self.naming[0]][0], new_transform) - } - - def to_range(self, x: Float) -> Float: - """ - Transform the parameters to the range of the prior. - - Parameters - ---------- - x : Float - The parameters to transform. - - Returns - ------- - x : dict - A dictionary of parameters with the transforms applied. - """ - return (self.xmax - self.xmin) / (1 + jnp.exp(-x)) + self.xmin + def __init__(self, parameter_names: list[str], **kwargs): + super().__init__(parameter_names) + self.composite = False + assert ( + self.n_dim == 1 + ), "StandardNormalDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: """ - Sample from a uniform distribution. + Sample from a standard normal distribution. Parameters ---------- @@ -211,513 +129,569 @@ def sample( Returns ------- - samples : - An array of shape (n_samples, n_dim) containing the samples. + samples : dict + Samples from the distribution. The keys are the names of the parameters. """ - samples = jax.random.uniform(rng_key, (n_samples,), minval=0, maxval=1) - samples = jnp.log(samples / (1 - samples)) + samples = jax.random.normal(rng_key, (n_samples,)) return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) + def log_prob(self, z: dict[str, Float]) -> Float: + variable = z[self.parameter_names[0]] + return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) -class Sphere(Prior): +class SequentialTransformPrior(Prior): """ - A prior on a sphere represented by Cartesian coordinates. - - Magnitude is sampled from a uniform distribution. + Transform a prior distribution by applying a sequence of transforms. + The space before the transform is named as x, + and the space after the transform is named as z """ - def __repr__(self): - return f"Sphere(naming={self.naming})" - - def __init__(self, naming: list[str], **kwargs): - name = naming[0] - self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] - self.transforms = { - self.naming[0]: ( - f"{naming}_x", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.cos(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[1]: ( - f"{naming}_y", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.sin(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[2]: ( - f"{naming}_z", - lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 3) - theta = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) - mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) - return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - theta = x[self.naming[0]] - phi = x[self.naming[1]] - mag = x[self.naming[2]] - output = jnp.where( - (mag > 1) - | (mag < 0) - | (phi > 2 * jnp.pi) - | (phi < 0) - | (theta > jnp.pi) - | (theta < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), - ) - return output - - -@jaxtyped(typechecker=typechecker) -class AlignedSpin(Prior): - """ - Prior distribution for the aligned (z) component of the spin. - - This assume the prior distribution on the spin magnitude to be uniform in [0, amax] - with its orientation uniform on a sphere - - p(chi) = -log(|chi| / amax) / 2 / amax - - This is useful when comparing results between an aligned-spin run and - a precessing spin run. - - See (A7) of https://arxiv.org/abs/1805.10457. - """ - - amax: Float = 0.99 - chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + base_prior: Prior + transforms: list[BijectiveTransform] def __repr__(self): - return f"Alignedspin(amax={self.amax}, naming={self.naming})" + return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" def __init__( self, - amax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, + base_prior: Prior, + transforms: list[BijectiveTransform], ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" - self.amax = amax - # build the interpolation table for the ppf of the one-sided distribution - chi_axis = jnp.linspace(1e-31, self.amax, num=1000) - cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax - self.chi_axis = chi_axis - self.cdf_vals = cdf_vals - - @property - def xmin(self): - return -self.amax - - @property - def xmax(self): - return self.amax + self.base_prior = base_prior + self.transforms = transforms + self.parameter_names = base_prior.parameter_names + for transform in transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) + self.composite = True def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from the Alignedspin distribution. - - for chi > 0; - p(chi) = -log(chi / amax) / amax # halved normalization constant - cdf(chi) = -chi * (log(chi / amax) - 1) / amax - - Since there is a pole at chi=0, we will sample with the following steps - 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise - 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) - 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) - 3. Map the quantile to chi via the ppf by checking against the table - built during the initialization - 4. add back the sign - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. + output = self.base_prior.sample(rng_key, n_samples) + return jax.vmap(self.transform)(output) + def log_prob(self, z: dict[str, Float]) -> Float: """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - # 1. calculate the sign of chi from the q_samples - sign_samples = jnp.where( - q_samples >= 0.5, - jnp.zeros_like(q_samples) + 1.0, - jnp.zeros_like(q_samples) - 1.0, - ) - # 2. remap q_samples - q_samples = jnp.where( - q_samples >= 0.5, - 2 * (q_samples - 0.5), - 2 * (0.5 - q_samples), - ) - # 3. map the quantile to chi via interpolation - samples = jnp.interp( - q_samples, - self.cdf_vals, - self.chi_axis, - ) - # 4. add back the sign - samples *= sign_samples - - return self.add_name(samples[None]) + Evaluating the probability of the transformed variable z. + This is what flowMC should sample from + """ + output = 0 + for transform in reversed(self.transforms): + z, log_jacobian = transform.inverse(z) + output += log_jacobian + output += self.base_prior.log_prob(z) + return output - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_p = jnp.where( - (variable >= self.amax) | (variable <= -self.amax), - jnp.zeros_like(variable) - jnp.inf, - jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), - ) - return log_p + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + for transform in self.transforms: + x = transform.forward(x) + return x -@jaxtyped(typechecker=typechecker) -class EarthFrame(Prior): +class CombinePrior(Prior): """ - Prior distribution for sky location in Earth frame. + A prior class constructed by joinning multiple priors together to form a multivariate prior. + This assumes the priors composing the Combine class are independent. """ - ifos: list = field(default_factory=list) - gmst: float = 0.0 - delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) + base_prior: list[Prior] = field(default_factory=list) def __repr__(self): - return f"EarthFrame(naming={self.naming})" - - def __init__(self, gps: Float, ifos: list, **kwargs): - self.naming = ["zenith", "azimuth"] - if len(ifos) < 2: - return ValueError( - "At least two detectors are needed to define the Earth frame" - ) - elif isinstance(ifos[0], str): - self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] - elif isinstance(ifos[0], GroundBased2G): - self.ifos = ifos[:1] - else: - return ValueError( - "ifos should be a list of detector names or GroundBased2G objects" - ) - self.gmst = float( - Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad + return ( + f"Combine(priors={self.base_prior}, parameter_names={self.parameter_names})" ) - self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex - - self.transforms = { - "azimuth": ( - "ra", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[0], - ), - "zenith": ( - "dec", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[1], - ), - } + + def __init__( + self, + priors: list[Prior], + ): + parameter_names = [] + for prior in priors: + parameter_names += prior.parameter_names + self.base_prior = priors + self.parameter_names = parameter_names + self.composite = True def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 2) - zenith = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - azimuth = jax.random.uniform( - rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi - ) - return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - zenith = x["zenith"] - azimuth = x["azimuth"] - output = jnp.where( - (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.zeros_like(0), - ) - return output + jnp.log(jnp.sin(zenith)) + output = {} + for prior in self.base_prior: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + return output + def log_prob(self, z: dict[str, Float]) -> Float: + output = 0.0 + for prior in self.base_prior: + output += prior.log_prob(z) + return output -@jaxtyped(typechecker=typechecker) -class PowerLaw(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ x^{\alpha} - """ - xmin: float = 0.0 - xmax: float = 1.0 - alpha: float = 0.0 - normalization: float = 1.0 +@jaxtyped(typechecker=typechecker) +class UniformPrior(SequentialTransformPrior): + xmin: float + xmax: float def __repr__(self): - return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + return f"UniformPrior(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" def __init__( self, xmin: float, xmax: float, - alpha: Union[Int, float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, + parameter_names: list[str], ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin > 0.0, "With negative alpha, xmin must > 0" - assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" + self.parameter_names = parameter_names + assert self.n_dim == 1, "UniformPrior needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - self.alpha = alpha - if alpha == -1: - self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) - else: - self.normalization = (1 + self.alpha) / ( - self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a power-law distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - if self.alpha == -1: - samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) - else: - samples = ( - self.xmin ** (1.0 + self.alpha) - + q_samples - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), + super().__init__( + LogisticDistribution([f"{self.parameter_names[0]}_base"]), + [ + LogitTransform( + ( + [f"{self.parameter_names[0]}_base"], + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + ) + ), + ScaleTransform( + ( + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + [f"{self.parameter_names[0]}-({xmin})"], + ), + xmax - xmin, + ), + OffsetTransform( + ([f"{self.parameter_names[0]}-({xmin})"], self.parameter_names), + xmin, + ), + ], ) - log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) - return log_p + log_in_range @jaxtyped(typechecker=typechecker) -class Exponential(Prior): +class SinePrior(SequentialTransformPrior): """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ exp(\alpha x) + A prior distribution where the pdf is proportional to sin(x) in the range [0, pi]. """ - xmin: float = 0.0 - xmax: float = jnp.inf - alpha: float = -1.0 - normalization: float = 1.0 - def __repr__(self): - return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + return f"SinePrior(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 1, "SinePrior needs to be 1D distributions" + super().__init__( + CosinePrior([f"{self.parameter_names[0]}-pi/2"]), + [ + OffsetTransform( + ( + ( + [f"{self.parameter_names[0]}-pi/2"], + [f"{self.parameter_names[0]}"], + ) + ), + jnp.pi / 2, + ) + ], + ) - def __init__( - self, - xmin: Float, - xmax: Float, - alpha: Union[Int, Float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin != -jnp.inf, "With negative alpha, xmin must finite" - if alpha > 0.0: - assert xmax != jnp.inf, "With positive alpha, xmax must finite" - assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" - assert self.n_dim == 1, "Exponential needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha +@jaxtyped(typechecker=typechecker) +class CosinePrior(SequentialTransformPrior): + """ + A prior distribution where the pdf is proportional to cos(x) in the range [-pi/2, pi/2]. + """ - self.normalization = self.alpha / ( - jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) + def __repr__(self): + return f"CosinePrior(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" + super().__init__( + UniformPrior(-1.0, 1.0, [f"sin({self.parameter_names[0]})"]), + [ + ArcSineTransform( + ([f"sin({self.parameter_names[0]})"], [self.parameter_names[0]]) + ) + ], ) - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a exponential distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - samples = ( - self.xmin - + jnp.log1p( - q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) - ) - / self.alpha - ) - return self.add_name(samples[None]) +@jaxtyped(typechecker=typechecker) +class UniformSpherePrior(CombinePrior): - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), + def __repr__(self): + return f"UniformSpherePrior(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str], **kwargs): + self.parameter_names = parameter_names + assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector" + self.parameter_names = [ + f"{self.parameter_names[0]}_mag", + f"{self.parameter_names[0]}_theta", + f"{self.parameter_names[0]}_phi", + ] + super().__init__( + [ + UniformPrior(0.0, 1.0, [self.parameter_names[0]]), + SinePrior([self.parameter_names[1]]), + UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + ] ) - log_p = self.alpha * variable + jnp.log(self.normalization) - return log_p + log_in_range @jaxtyped(typechecker=typechecker) -class Normal(Prior): - mean: Float = 0.0 - std: Float = 1.0 +class PowerLawPrior(SequentialTransformPrior): + xmin: float + xmax: float + alpha: float def __repr__(self): - return f"Normal(mean={self.mean}, std={self.std})" + return f"PowerLawPrior(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" def __init__( self, - mean: Float, - std: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, + xmin: float, + xmax: float, + alpha: float, + parameter_names: list[str], ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Normal needs to be 1D distributions" - self.mean = mean - self.std = std - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a normal distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - samples = jax.random.normal(rng_key, (n_samples,)) - samples = self.mean + samples * self.std - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = ( - -0.5 * jnp.log(2 * jnp.pi) - - jnp.log(self.std) - - 0.5 * ((variable - self.mean) / self.std) ** 2 + self.parameter_names = parameter_names + assert self.n_dim == 1, "Power law needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + assert self.xmin < self.xmax, "xmin must be less than xmax" + assert self.xmin > 0.0, "x must be positive" + if self.alpha == -1.0: + transform = ParetoTransform( + ([f"{self.parameter_names[0]}_before_transform"], self.parameter_names), + xmin, + xmax, + ) + else: + transform = PowerLawTransform( + ([f"{self.parameter_names[0]}_before_transform"], self.parameter_names), + xmin, + xmax, + alpha, + ) + super().__init__( + LogisticDistribution([f"{self.parameter_names[0]}_base"]), + [ + LogitTransform( + ( + [f"{self.parameter_names[0]}_base"], + [f"{self.parameter_names[0]}_before_transform"], + ) + ), + transform, + ], ) - return output -class Composite(Prior): - priors: list[Prior] = field(default_factory=list) - - def __repr__(self): - return f"Composite(priors={self.priors}, naming={self.naming})" - - def __init__( - self, - priors: list[Prior], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - naming = [] - self.transforms = {} - for prior in priors: - naming += prior.naming - self.transforms.update(prior.transforms) - self.priors = priors - self.naming = naming - self.transforms.update(transforms) +@jaxtyped(typechecker=typechecker) +class UniformComponentChirpMassPrior(PowerLawPrior): + """ + A prior in the range [xmin, xmax) for chirp mass which assumes the + component masses to be uniformly distributed. - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - output = {} - for prior in self.priors: - rng_key, subkey = jax.random.split(rng_key) - output.update(prior.sample(subkey, n_samples)) - return output + p(M_c) ~ M_c + """ - def log_prob(self, x: dict[str, Float]) -> Float: - output = 0.0 - for prior in self.priors: - output += prior.log_prob(x) - return output + def __repr__(self): + return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + + def __init__(self, xmin: float, xmax: float): + super().__init__(xmin, xmax, 1.0, ["M_c"]) + + +def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output + + +# ====================== Things below may need rework ====================== + + +# @jaxtyped(typechecker=typechecker) +# class AlignedSpin(Prior): +# """ +# Prior distribution for the aligned (z) component of the spin. + +# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] +# with its orientation uniform on a sphere + +# p(chi) = -log(|chi| / amax) / 2 / amax + +# This is useful when comparing results between an aligned-spin run and +# a precessing spin run. + +# See (A7) of https://arxiv.org/abs/1805.10457. +# """ + +# amax: Float = 0.99 +# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) +# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + +# def __repr__(self): +# return f"Alignedspin(amax={self.amax}, naming={self.naming})" + +# def __init__( +# self, +# amax: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" +# self.amax = amax + +# # build the interpolation table for the ppf of the one-sided distribution +# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) +# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax +# self.chi_axis = chi_axis +# self.cdf_vals = cdf_vals + +# @property +# def xmin(self): +# return -self.amax + +# @property +# def xmax(self): +# return self.amax + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from the Alignedspin distribution. + +# for chi > 0; +# p(chi) = -log(chi / amax) / amax # halved normalization constant +# cdf(chi) = -chi * (log(chi / amax) - 1) / amax + +# Since there is a pole at chi=0, we will sample with the following steps +# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise +# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) +# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) +# 3. Map the quantile to chi via the ppf by checking against the table +# built during the initialization +# 4. add back the sign + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# # 1. calculate the sign of chi from the q_samples +# sign_samples = jnp.where( +# q_samples >= 0.5, +# jnp.zeros_like(q_samples) + 1.0, +# jnp.zeros_like(q_samples) - 1.0, +# ) +# # 2. remap q_samples +# q_samples = jnp.where( +# q_samples >= 0.5, +# 2 * (q_samples - 0.5), +# 2 * (0.5 - q_samples), +# ) +# # 3. map the quantile to chi via interpolation +# samples = jnp.interp( +# q_samples, +# self.cdf_vals, +# self.chi_axis, +# ) +# # 4. add back the sign +# samples *= sign_samples + +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_p = jnp.where( +# (variable >= self.amax) | (variable <= -self.amax), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), +# ) +# return log_p + + +# @jaxtyped(typechecker=typechecker) +# class EarthFrame(Prior): +# """ +# Prior distribution for sky location in Earth frame. +# """ + +# ifos: list = field(default_factory=list) +# gmst: float = 0.0 +# delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) + +# def __repr__(self): +# return f"EarthFrame(naming={self.naming})" + +# def __init__(self, gps: Float, ifos: list, **kwargs): +# self.naming = ["zenith", "azimuth"] +# if len(ifos) < 2: +# return ValueError( +# "At least two detectors are needed to define the Earth frame" +# ) +# elif isinstance(ifos[0], str): +# self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] +# elif isinstance(ifos[0], GroundBased2G): +# self.ifos = ifos[:1] +# else: +# return ValueError( +# "ifos should be a list of detector names or GroundBased2G objects" +# ) +# self.gmst = float( +# Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad +# ) +# self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex + +# self.transforms = { +# "azimuth": ( +# "ra", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[0], +# ), +# "zenith": ( +# "dec", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[1], +# ), +# } + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# rng_keys = jax.random.split(rng_key, 2) +# zenith = jnp.arccos( +# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) +# ) +# azimuth = jax.random.uniform( +# rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi +# ) +# return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# zenith = x["zenith"] +# azimuth = x["azimuth"] +# output = jnp.where( +# (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), +# jnp.zeros_like(0) - jnp.inf, +# jnp.zeros_like(0), +# ) +# return output + jnp.log(jnp.sin(zenith)) + + +# @jaxtyped(typechecker=typechecker) +# class Exponential(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ exp(\alpha x) +# """ + +# xmin: float = 0.0 +# xmax: float = jnp.inf +# alpha: float = -1.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: Float, +# xmax: Float, +# alpha: Union[Int, Float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin != -jnp.inf, "With negative alpha, xmin must finite" +# if alpha > 0.0: +# assert xmax != jnp.inf, "With positive alpha, xmax must finite" +# assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" +# assert self.n_dim == 1, "Exponential needs to be 1D distributions" + +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha + +# self.normalization = self.alpha / ( +# jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a exponential distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# samples = ( +# self.xmin +# + jnp.log1p( +# q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) +# ) +# / self.alpha +# ) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * variable + jnp.log(self.normalization) +# return log_p + log_in_range diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 3f0decce..8721c176 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -37,7 +37,7 @@ def inner_product( return 4.0 * jnp.real(trapezoid(integrand, dx=df)) -def m1m2_to_Mq(m1: Float, m2: Float): +def m1_m2_to_M_q(m1: Float, m2: Float): """ Transforming the primary mass m1 and secondary mass m2 to the Total mass M and mass ratio q. @@ -61,7 +61,7 @@ def m1m2_to_Mq(m1: Float, m2: Float): return M_tot, q -def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): +def M_q_to_m1_m2(trans_M_tot: Float, trans_q: Float): """ Transforming the Total mass M and mass ratio q to the primary mass m1 and secondary mass m2. @@ -87,14 +87,14 @@ def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): return m1, m2 -def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: +def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]: """ - Transforming the chirp mass Mc and mass ratio q to the primary mass m1 and + Transforming the chirp mass M_c and mass ratio q to the primary mass m1 and secondary mass m2. Parameters ---------- - Mc : Float + M_c : Float Chirp mass. q : Float Mass ratio. @@ -107,36 +107,100 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: Secondary mass. """ eta = q / (1 + q) ** 2 - M_tot = Mc / eta ** (3.0 / 5) + M_tot = M_c / eta ** (3.0 / 5) m1 = M_tot / (1 + q) m2 = m1 * q return m1, m2 -def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: +def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: """ - Transforming the right ascension ra and declination dec to the polar angle - theta and azimuthal angle phi. + Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c + and mass ratio q. Parameters ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. + m1 : Float + Primary mass. + m2 : Float + Secondary mass. Returns ------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. + M_c : Float + Chirp mass. + q : Float + Mass ratio. """ - phi = ra - gmst - theta = jnp.pi / 2 - dec - return theta, phi + M_tot = m1 + m2 + eta = m1 * m2 / M_tot**2 + M_c = M_tot * eta ** (3.0 / 5) + q = m2 / m1 + return M_c, q + + +def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the total mass M + and symmetric mass ratio eta. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M : Float + Total mass. + eta : Float + Symmetric mass ratio. + """ + M = m1 + m2 + eta = m1 * m2 / M**2 + return M, eta + + +def q_to_eta(q: Float) -> Float: + """ + Transforming the chirp mass M_c and mass ratio q to the symmetric mass ratio eta. + + Parameters + ---------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + + Returns + ------- + eta : Float + Symmetric mass ratio. + """ + eta = q / (1 + q) ** 2 + return eta + + +def eta_to_q(eta: Float) -> Float: + """ + Transforming the symmetric mass ratio eta to the mass ratio q. + + Copied and modified from bilby/gw/conversion.py + + Parameters + ---------- + eta : Float + Symmetric mass ratio. + + Returns + ------- + q : Float + Mass ratio. + """ + temp = 1 / eta / 2 - 1 + return temp - (temp**2 - 1) ** 0.5 def euler_rotation(delta_x: Float[Array, " 3"]): @@ -149,11 +213,10 @@ def euler_rotation(delta_x: Float[Array, " 3"]): Copied and modified from bilby-cython/geometry.pyx """ - norm = jnp.power( - delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5 - ) + norm = jnp.linalg.vector_norm(delta_x) + cos_beta = delta_x[2] / norm - sin_beta = jnp.power(1 - cos_beta**2, 0.5) + sin_beta = jnp.sqrt(1 - cos_beta**2) alpha = jnp.atan2(-delta_x[1] * cos_beta, delta_x[0]) gamma = jnp.atan2(delta_x[1], delta_x[0]) @@ -182,8 +245,8 @@ def euler_rotation(delta_x: Float[Array, " 3"]): return rotation -def zenith_azimuth_to_theta_phi( - zenith: Float, azimuth: Float, delta_x: Float[Array, " 3"] +def angle_rotation( + zenith: Float, azimuth: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame. @@ -211,8 +274,6 @@ def zenith_azimuth_to_theta_phi( sin_zenith = jnp.sin(zenith) cos_zenith = jnp.cos(zenith) - rotation = euler_rotation(delta_x) - theta = jnp.acos( rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth @@ -228,7 +289,7 @@ def zenith_azimuth_to_theta_phi( + rotation[0][2] * cos_zenith, ) + 2 * jnp.pi, - (2 * jnp.pi), + 2 * jnp.pi, ) return theta, phi @@ -255,11 +316,12 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F """ ra = phi + gmst dec = jnp.pi / 2 - theta + ra = ra % (2 * jnp.pi) return ra, dec def zenith_azimuth_to_ra_dec( - zenith: Float, azimuth: Float, gmst: Float, delta_x: Float[Array, " 3"] + zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. @@ -284,12 +346,67 @@ def zenith_azimuth_to_ra_dec( dec : Float Declination. """ - theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + theta, phi = angle_rotation(zenith, azimuth, rotation) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) - ra = ra % (2 * jnp.pi) return ra, dec +def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the right ascension ra and declination dec to the polar angle + theta and azimuthal angle phi. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ + phi = ra - gmst + theta = jnp.pi / 2 - dec + phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) + return theta, phi + + +def ra_dec_to_zenith_azimuth( + ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] +) -> tuple[Float, Float]: + """ + Transforming the right ascension and declination to the zenith angle and azimuthal angle. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + delta_x : Float + The vector pointing from the first detector to the second detector. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + """ + theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) + zenith, azimuth = angle_rotation(theta, phi, rotation) + return zenith, azimuth + + def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]: """ A numerically stable method to evaluate log of diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py new file mode 100644 index 00000000..9aea6503 --- /dev/null +++ b/src/jimgw/transforms.py @@ -0,0 +1,541 @@ +from abc import ABC +from typing import Callable + +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped + +from jimgw.single_event.utils import ( + Mc_q_to_m1_m2, + m1_m2_to_Mc_q, + q_to_eta, + eta_to_q, + ra_dec_to_zenith_azimuth, + zenith_azimuth_to_ra_dec, + euler_rotation, +) + + +class Transform(ABC): + """ + Base class for transform. + The purpose of this class is purely for keeping name + """ + + name_mapping: tuple[list[str], list[str]] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + self.name_mapping = name_mapping + + def propagate_name(self, x: list[str]) -> list[str]: + input_set = set(x) + from_set = set(self.name_mapping[0]) + to_set = set(self.name_mapping[1]) + return list(input_set - from_set | to_set) + + +class NtoMTransform(Transform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + def forward(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Push forward the input x to transformed coordinate y. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + """ + x_copy = x.copy() + output_params = self.transform_func(x_copy) + jax.tree.map( + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy + + +class NtoNTransform(NtoMTransform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + @property + def n_dim(self) -> int: + return len(self.name_mapping[0]) + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + """ + Transform the input x to transformed coordinate y and return the log Jacobian determinant. + This only works if the transform is a N -> N transform. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + log_det : Float + The log Jacobian determinant. + """ + x_copy = x.copy() + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jax.tree.map( + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy, jacobian + + +class BijectiveTransform(NtoNTransform): + + inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + """ + Inverse transform the input y to original coordinate x. + + Parameters + ---------- + y : dict[str, Float] + The transformed dictionary. + + Returns + ------- + x : dict[str, Float] + The original dictionary. + log_det : Float + The log Jacobian determinant. + """ + y_copy = y.copy() + transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) + output_params = self.inverse_transform_func(transform_params) + jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jax.tree.map( + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy, jacobian + + def backward(self, y: dict[str, Float]) -> dict[str, Float]: + """ + Pull back the input y to original coordinate x and return the log Jacobian determinant. + + Parameters + ---------- + y : dict[str, Float] + The transformed dictionary. + + Returns + ------- + x : dict[str, Float] + The original dictionary. + """ + y_copy = y.copy() + output_params = self.inverse_transform_func(y_copy) + jax.tree.map( + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy + + +@jaxtyped(typechecker=typechecker) +class ScaleTransform(BijectiveTransform): + scale: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + scale: Float, + ): + super().__init__(name_mapping) + self.scale = scale + self.transform_func = lambda x: { + name_mapping[1][i]: x[name_mapping[0][i]] * self.scale + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: x[name_mapping[1][i]] / self.scale + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class OffsetTransform(BijectiveTransform): + offset: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + offset: Float, + ): + super().__init__(name_mapping) + self.offset = offset + self.transform_func = lambda x: { + name_mapping[1][i]: x[name_mapping[0][i]] + self.offset + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: x[name_mapping[1][i]] - self.offset + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class LogitTransform(BijectiveTransform): + """ + Logit transform following + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: 1 / (1 + jnp.exp(-x[name_mapping[0][i]])) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.log( + x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]]) + ) + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class ArcSineTransform(BijectiveTransform): + """ + ArcSine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.arcsin(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.sin(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class BoundToBound(BijectiveTransform): + """ + Bound to bound transformation + """ + + original_lower_bound: Float[Array, " n_dim"] + original_upper_bound: Float[Array, " n_dim"] + target_lower_bound: Float[Array, " n_dim"] + target_upper_bound: Float[Array, " n_dim"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float[Array, " n_dim"], + original_upper_bound: Float[Array, " n_dim"], + target_lower_bound: Float[Array, " n_dim"], + target_upper_bound: Float[Array, " n_dim"], + ): + super().__init__(name_mapping) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + self.target_lower_bound = target_lower_bound + self.target_upper_bound = target_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: (x[name_mapping[0][i]] - self.original_lower_bound[i]) + * (self.target_upper_bound[i] - self.target_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) + + self.target_lower_bound[i] + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: (x[name_mapping[1][i]] - self.target_lower_bound[i]) + * (self.original_upper_bound[i] - self.original_lower_bound[i]) + / (self.target_upper_bound[i] - self.target_lower_bound[i]) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class BoundToUnbound(BijectiveTransform): + """ + Bound to unbound transformation + """ + + original_lower_bound: Float + original_upper_bound: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, + original_upper_bound: Float, + ): + + def logit(x): + return jnp.log(x / (1 - x)) + + super().__init__(name_mapping) + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.original_upper_bound = jnp.atleast_1d(original_upper_bound) + + self.transform_func = lambda x: { + name_mapping[1][i]: logit( + (x[name_mapping[0][i]] - self.original_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) + ) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: (self.original_upper_bound[i] - self.original_lower_bound[i]) + / (1 + jnp.exp(-x[name_mapping[1][i]])) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class SingleSidedUnboundTransform(BijectiveTransform): + """ + Unbound upper limit transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.log(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to component masses + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + Mc = x[name_mapping[0][0]] + q = x[name_mapping[0][1]] + m1, m2 = Mc_q_to_m1_m2(Mc, q) + return {name_mapping[1][0]: m1, name_mapping[1][1]: m2} + + self.transform_func = named_transform + + def named_inverse_transform(x): + m1 = x[name_mapping[1][0]] + m2 = x[name_mapping[1][1]] + Mc, q = m1_m2_to_Mc_q(m1, m2) + return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + + self.inverse_transform_func = named_inverse_transform + + +@jaxtyped(typechecker=typechecker) +class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): + """ + Transform mass ratio to symmetric mass ratio + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + self.transform_func = lambda x: { + name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]]) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]]) + } + + +@jaxtyped(typechecker=typechecker) +class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): + """ + Transform sky frame to detector frame sky position + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + rotation: Float[Array, " 3 3"] + rotation_inv: Float[Array, " 3 3"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gmst: Float, + delta_x: Float, + ): + super().__init__(name_mapping) + + self.gmst = gmst + self.rotation = euler_rotation(delta_x) + self.rotation_inv = jnp.linalg.inv(self.rotation) + + def named_transform(x): + ra = x[name_mapping[0][0]] + dec = x[name_mapping[0][1]] + zenith, azimuth = ra_dec_to_zenith_azimuth( + ra, dec, self.gmst, self.rotation + ) + return {name_mapping[1][0]: zenith, name_mapping[1][1]: azimuth} + + self.transform_func = named_transform + + def named_inverse_transform(x): + zenith = x[name_mapping[1][0]] + azimuth = x[name_mapping[1][1]] + ra, dec = zenith_azimuth_to_ra_dec( + zenith, azimuth, self.gmst, self.rotation_inv + ) + return {name_mapping[0][0]: ra, name_mapping[0][1]: dec} + + self.inverse_transform_func = named_inverse_transform + + +# class PowerLawTransform(UnivariateTransform): +# """ +# PowerLaw transformation +# Parameters +# ---------- +# name_mapping : tuple[list[str], list[str]] +# The name mapping between the input and output dictionary. +# """ + +# xmin: Float +# xmax: Float +# alpha: Float + +# def __init__( +# self, +# name_mapping: tuple[list[str], list[str]], +# xmin: Float, +# xmax: Float, +# alpha: Float, +# ): +# super().__init__(name_mapping) +# self.xmin = xmin +# self.xmax = xmax +# self.alpha = alpha +# self.transform_func = lambda x: ( +# self.xmin ** (1.0 + self.alpha) +# + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) + + +# class ParetoTransform(UnivariateTransform): +# """ +# Pareto transformation: Power law when alpha = -1 +# Parameters +# ---------- +# name_mapping : tuple[list[str], list[str]] +# The name mapping between the input and output dictionary. +# """ + +# def __init__( +# self, +# name_mapping: tuple[list[str], list[str]], +# xmin: Float, +# xmax: Float, +# ): +# super().__init__(name_mapping) +# self.xmin = xmin +# self.xmax = xmax +# self.transform_func = lambda x: self.xmin * jnp.exp( +# x * jnp.log(self.xmax / self.xmin) +# ) diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py new file mode 100644 index 00000000..9adda574 --- /dev/null +++ b/test/integration/test_GW150914.py @@ -0,0 +1,131 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound, MassRatioToSymmetricMassRatioTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior( + 0.125, + 1., + parameter_names=["q"], +) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +# Current likelihood sampling will fail and give nan because of large number +dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + Mc_prior, + q_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) +] + +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) diff --git a/test/test_prior.py b/test/test_prior.py deleted file mode 100644 index 53b065da..00000000 --- a/test/test_prior.py +++ /dev/null @@ -1 +0,0 @@ -from jimgw.prior import Composite, Unconstrained_Uniform, Uniform diff --git a/test/test_detector.py b/test/unit/test_detector.py similarity index 100% rename from test/test_detector.py rename to test/unit/test_detector.py diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py new file mode 100644 index 00000000..20d71de4 --- /dev/null +++ b/test/unit/test_prior.py @@ -0,0 +1,123 @@ +from jimgw.prior import * +import scipy.stats as stats + + +class TestUnivariatePrior: + def test_logistic(self): + p = LogisticDistribution(["x"]) + # Check that the log_prob is finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Cross-check log_prob with scipy.stats.logistic + x = jnp.linspace(-10.0, 10.0, 1000) + assert jnp.allclose(jax.vmap(p.log_prob)(p.add_name(x[None])), stats.logistic.logpdf(x)) + + def test_standard_normal(self): + p = StandardNormalDistribution(["x"]) + # Check that the log_prob is finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Cross-check log_prob with scipy.stats.norm + x = jnp.linspace(-10.0, 10.0, 1000) + assert jnp.allclose(jax.vmap(p.log_prob)(p.add_name(x[None])), stats.norm.logpdf(x)) + + def test_uniform(self): + xmin, xmax = -10.0, 10.0 + p = UniformPrior(xmin, xmax, ["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) + # Check that the log_prob is correct in the support + samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin))) + + def test_sine(self): + p = SinePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) + # Check that the log_prob is finite + samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Check that the log_prob is correct in the support + x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) + y = jax.vmap(p.base_prior.base_prior.transform)(x) + y = jax.vmap(p.base_prior.transform)(y) + y = jax.vmap(p.transform)(y) + assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.sin(y['x'])/2.0)) + + def test_cosine(self): + p = CosinePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) + # Check that the log_prob is finite + samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Check that the log_prob is correct in the support + x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) + y = jax.vmap(p.base_prior.transform)(x) + y = jax.vmap(p.transform)(y) + assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.cos(y['x'])/2.0)) + + def test_uniform_sphere(self): + p = UniformSpherePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x_mag'])) + assert jnp.all(jnp.isfinite(samples['x_theta'])) + assert jnp.all(jnp.isfinite(samples['x_phi'])) + # Check that the log_prob is finite + samples = {} + for i in range(3): + samples.update(trace_prior_parent(p, [])[i].sample(jax.random.PRNGKey(0), 10000)) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + + def test_power_law(self): + def powerlaw_log_pdf(x, alpha, xmin, xmax): + if alpha == -1.0: + normalization = 1./(jnp.log(xmax) - jnp.log(xmin)) + else: + normalization = (1.0 + alpha) / (xmax**(1.0 + alpha) - xmin**(1.0 + alpha)) + return jnp.log(normalization) + alpha * jnp.log(x) + + def func(alpha): + xmin = 0.05 + xmax = 10.0 + p = PowerLawPrior(xmin, xmax, alpha, ["x"]) + # Check that all the samples are finite + powerlaw_samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) + + # Check that all the log_probs are finite + samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x_base'] + base_log_p = jax.vmap(p.log_prob, [0])({'x_base':samples}) + assert jnp.all(jnp.isfinite(base_log_p)) + + # Check that the log_prob is correct in the support + samples = jnp.linspace(-10.0, 10.0, 1000) + transformed_samples = jax.vmap(p.transform)({'x_base': samples})['x'] + # cut off the samples that are outside the support + samples = samples[transformed_samples >= xmin] + transformed_samples = transformed_samples[transformed_samples >= xmin] + samples = samples[transformed_samples <= xmax] + transformed_samples = transformed_samples[transformed_samples <= xmax] + # log pdf of powerlaw + assert jnp.allclose(jax.vmap(p.log_prob)({'x_base':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) + + # Test Pareto Transform + func(-1.0) + # Test other values of alpha + print("Testing PowerLawPrior") + positive_alpha = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0] + for alpha_val in positive_alpha: + func(alpha_val) + negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] + for alpha_val in negative_alpha: + func(alpha_val)