From 77e0247c512d0c5ed6c9d736e6015b3fb3160490 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Tue, 23 Jul 2024 13:55:23 -0400 Subject: [PATCH 001/172] Update Single_event_runManager.py --- example/Single_event_runManager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index c0878afc..88ffe52b 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -31,7 +31,6 @@ run = SingleEventRun( seed=0, - path="test_data/GW150914/", detectors=["H1", "L1"], priors={ "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0}, From 75c74a36517147f17e785cdb8dd17e2984279fac Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 10:16:11 -0400 Subject: [PATCH 002/172] scaffolding jim to handle naming. Also isort some import in prior.py --- src/jimgw/jim.py | 8 +++++++- src/jimgw/prior.py | 7 ++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 0ab1105a..15acda7c 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -13,9 +13,15 @@ class Jim(object): """ Master class for interfacing with flowMC - """ + likelihood: LikelihoodBase + prior: Prior + + seed: int + + parameter_names: list[str] + def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): self.Likelihood = likelihood self.Prior = prior diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index f1827753..702dcd31 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -3,12 +3,13 @@ import jax import jax.numpy as jnp +from astropy.time import Time +from beartype import beartype as typechecker from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped -from beartype import beartype as typechecker -from jimgw.single_event.utils import zenith_azimuth_to_ra_dec + from jimgw.single_event.detector import GroundBased2G, detector_preset -from astropy.time import Time +from 
jimgw.single_event.utils import zenith_azimuth_to_ra_dec class Prior(Distribution): From c749dbd06b512a5c9285889f570012a229522a2e Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 10:19:32 -0400 Subject: [PATCH 003/172] Rename variables in jim --- src/jimgw/jim.py | 46 ++++++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 15acda7c..39c348fd 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -18,13 +18,19 @@ class Jim(object): likelihood: LikelihoodBase prior: Prior - seed: int - parameter_names: list[str] + sampler: Sampler - def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): - self.Likelihood = likelihood - self.Prior = prior + def __init__( + self, + likelihood: LikelihoodBase, + prior: Prior, + parameter_names: list[str], + **kwargs, + ): + self.likelihood = likelihood + self.prior = prior + self.parameter_names = parameter_names seed = kwargs.get("seed", 0) @@ -39,11 +45,11 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): rng_key, subkey = jax.random.split(rng_key) model = MaskedCouplingRQSpline( - self.Prior.n_dim, num_layers, hidden_size, num_bins, subkey + self.prior.n_dim, num_layers, hidden_size, num_bins, subkey ) self.Sampler = Sampler( - self.Prior.n_dim, + self.prior.n_dim, rng_key, None, # type: ignore local_sampler, @@ -52,15 +58,15 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): ) def posterior(self, params: Float[Array, " n_dim"], data: dict): - prior_params = self.Prior.add_name(params.T) - prior = self.Prior.log_prob(prior_params) + prior_params = self.prior.add_name(params.T) + prior = self.prior.log_prob(prior_params) return ( - self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior + self.likelihood.evaluate(self.prior.transform(prior_params), data) + prior ) def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if 
initial_guess.size == 0: - initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains) + initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T self.Sampler.sample(initial_guess, None) # type: ignore @@ -73,7 +79,7 @@ def maximize_likelihood( ): key = jax.random.PRNGKey(seed) set_nwalkers = set_nwalkers - initial_guess = self.Prior.sample(key, set_nwalkers) + initial_guess = self.prior.sample(key, set_nwalkers) def negative_posterior(x: Float[Array, " n_dim"]): return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet @@ -84,7 +90,7 @@ def negative_posterior(x: Float[Array, " n_dim"]): print("Done compiling") print("Starting the optimizer") - optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True) + optimizer = EvolutionaryOptimizer(self.prior.n_dim, verbose=True) _ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return best_fit @@ -98,19 +104,19 @@ def print_summary(self, transform: bool = True): train_summary = self.Sampler.get_sampler_state(training=True) production_summary = self.Sampler.get_sampler_state(training=False) - training_chain = train_summary["chains"].reshape(-1, self.Prior.n_dim).T - training_chain = self.Prior.add_name(training_chain) + training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T + training_chain = self.prior.add_name(training_chain) if transform: - training_chain = self.Prior.transform(training_chain) + training_chain = self.prior.transform(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] training_loss = train_summary["loss_vals"] - production_chain = production_summary["chains"].reshape(-1, self.Prior.n_dim).T - production_chain = self.Prior.add_name(production_chain) + production_chain = 
production_summary["chains"].reshape(-1, self.prior.n_dim).T + production_chain = self.prior.add_name(production_chain) if transform: - production_chain = self.Prior.transform(production_chain) + production_chain = self.prior.transform(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -166,7 +172,7 @@ def get_samples(self, training: bool = False) -> dict: else: chains = self.Sampler.get_sampler_state(training=False)["chains"] - chains = self.Prior.transform(self.Prior.add_name(chains.transpose(2, 0, 1))) + chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1))) return chains def plot(self): From f45b41e83822888adba6ea358794e832c7db0952 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 11:15:58 -0400 Subject: [PATCH 004/172] Starting tracking names in jim --- src/jimgw/jim.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 39c348fd..08b82b05 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -18,6 +18,7 @@ class Jim(object): likelihood: LikelihoodBase prior: Prior + # Name of parameters to sample from parameter_names: list[str] sampler: Sampler @@ -57,12 +58,22 @@ def __init__( **kwargs, ) + def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: + """ + Turn an array into a dictionary + + Parameters + ---------- + x : Array + An array of parameters. Shape (n_dim,). 
+ """ + + return dict(zip(self.parameter_names, x)) + def posterior(self, params: Float[Array, " n_dim"], data: dict): - prior_params = self.prior.add_name(params.T) - prior = self.prior.log_prob(prior_params) - return ( - self.likelihood.evaluate(self.prior.transform(prior_params), data) + prior - ) + named_params = self.add_name(params) + prior = self.prior.log_prob(named_params) + return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: From dd6e0de37c6d4ca63aa55bd77b7108e8ca51cef0 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 13:53:00 -0400 Subject: [PATCH 005/172] Add Transform class --- src/jimgw/transforms.py | 63 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/jimgw/transforms.py diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py new file mode 100644 index 00000000..0795d81d --- /dev/null +++ b/src/jimgw/transforms.py @@ -0,0 +1,63 @@ +from dataclasses import field +from typing import Callable, Union + +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped + + +class Transform: + """ + Base class for transform. + + The idea of transform should be used on distribtuion, + """ + + transform_func: Callable[[Float[Array, " N"]], Float[Array, " M"]] + jacobian_func: Callable[[Float[Array, " N"]], Float] + name_mapping: tuple[list[str], list[str]] + + def __init__( + self, + transform_func: Callable[[Float[Array, " N"]], Float[Array, " M"]], + name_mapping: tuple[list[str], list[str]], + ): + self.transform_func = transform_func + self.jacobian_func = jax.jacfwd(transform_func) + self.name_mapping = name_mapping + + def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + """ + Transform the input x to transformed coordinate y and return the log Jacobian determinant. 
+ This only works if the transform is a N -> N transform. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + log_det : Float + The log Jacobian determinant. + """ + return self.transform_func(x), jnp.log(jnp.linalg.det(self.jacobian_func(x))) + + def push_forward(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Push forward the input x to transformed coordinate y. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + """ + return self.transform_func(x) \ No newline at end of file From f11f8c889d69caa2dc0b5a1dd2bcc58857bb32e0 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 14:42:34 -0400 Subject: [PATCH 006/172] Add LogitToUniform Transform --- src/jimgw/transforms.py | 64 ++++++++++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 0795d81d..a261a7fe 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,33 +1,35 @@ +from abc import ABC, abstractmethod from dataclasses import field from typing import Callable, Union +from chex import assert_rank import jax import jax.numpy as jnp from beartype import beartype as typechecker -from jaxtyping import Float, Array, jaxtyped +from jaxtyping import Array, Float, jaxtyped -class Transform: +class Transform(ABC): """ Base class for transform. 
The idea of transform should be used on distribtuion, """ - transform_func: Callable[[Float[Array, " N"]], Float[Array, " M"]] - jacobian_func: Callable[[Float[Array, " N"]], Float] name_mapping: tuple[list[str], list[str]] + transform_func: Callable[[dict[str, Float]], dict[str, Float]] def __init__( self, - transform_func: Callable[[Float[Array, " N"]], Float[Array, " M"]], name_mapping: tuple[list[str], list[str]], ): - self.transform_func = transform_func - self.jacobian_func = jax.jacfwd(transform_func) self.name_mapping = name_mapping def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + return self.transform(x) + + @abstractmethod + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Transform the input x to transformed coordinate y and return the log Jacobian determinant. This only works if the transform is a N -> N transform. @@ -44,9 +46,10 @@ def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. """ - return self.transform_func(x), jnp.log(jnp.linalg.det(self.jacobian_func(x))) - - def push_forward(self, x: dict[str, Float]) -> dict[str, Float]: + raise NotImplementedError + + @abstractmethod + def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ Push forward the input x to transformed coordinate y. @@ -60,4 +63,43 @@ def push_forward(self, x: dict[str, Float]) -> dict[str, Float]: y : dict[str, Float] The transformed dictionary. """ - return self.transform_func(x) \ No newline at end of file + raise NotImplementedError + +class LogitToUniform(Transform): + """ + Transform from unconstrained space to uniform space. + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + bounds : tuple[Float, Float] + The lower and upper bounds of the uniform distribution. 
+ + """ + + bounds: tuple[Float, Float] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + bounds: tuple[Float, Float], + ): + super().__init__(name_mapping) + self.bounds = bounds + self.transform_func = lambda x: (self.bounds[1] - self.bounds[0]) / (1 + jnp.exp(-x)) + self.bounds[0] + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + input_params = x.pop(self.name_mapping[0][0]) + assert_rank(input_params, 0) + output_params = self.transform_func(input_params) + jacobian = jax.jacfwd(self.transform_func)(input_params) + x[self.name_mapping[1][0]] = output_params + return x, jnp.log(jacobian) + + def forward(self, x: dict[str, Float]) -> dict[str, Float]: + input_params = x.pop(self.name_mapping[0][0]) + assert_rank(input_params, 0) + output_params = self.transform_func(input_params) + x[self.name_mapping[1][0]] = output_params + return x \ No newline at end of file From c1192f2c3a1031ea0b6a68533c27394b03a96732 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:24:11 -0400 Subject: [PATCH 007/172] Add logit distribution --- src/jimgw/prior.py | 79 +++++++++++++++++++++-------------------- src/jimgw/transforms.py | 2 +- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 702dcd31..e7b1acda 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -23,57 +23,19 @@ class Prior(Distribution): """ naming: list[str] - transforms: dict[str, tuple[str, Callable]] = field(default_factory=dict) @property def n_dim(self): return len(self.naming) - def __init__( - self, naming: list[str], transforms: dict[str, tuple[str, Callable]] = {} - ): + def __init__(self, naming: list[str]): """ Parameters ---------- naming : list[str] A list of names for the parameters of the prior. - transforms : dict[tuple[str,Callable]] - A dictionary of transforms to apply to the parameters. 
The keys are - the names of the parameters and the values are a tuple of the name - of the transform and the transform itself. """ self.naming = naming - self.transforms = {} - - def make_lambda(name): - return lambda x: x[name] - - for name in naming: - if name in transforms: - self.transforms[name] = transforms[name] - else: - # Without the function, the lambda will refer to the variable name instead of its value, - # which will make lambda reference the last value of the variable name - self.transforms[name] = (name, make_lambda(name)) - - def transform(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Apply the transforms to the parameters. - - Parameters - ---------- - x : dict - A dictionary of parameters. Names should match the ones in the prior. - - Returns - ------- - x : dict - A dictionary of parameters with the transforms applied. - """ - output = {} - for value in self.transforms.values(): - output[value[0]] = value[1](x) - return output def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ @@ -96,6 +58,45 @@ def log_prob(self, x: dict[str, Array]) -> Float: raise NotImplementedError +@jaxtyped(typechecker=typechecker) +class Logit(Prior): + + def __repr__(self): + return f"Logit(naming={self.naming})" + + def __init__(self, naming: list[str], **kwargs): + super().__init__(naming) + assert self.n_dim == 1, "Logit needs to be 1D distributions" + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a logit distribution. + + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. 
+ + """ + samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) + samples = jnp.log(samples / (1 - samples)) + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Float]) -> Float: + variable = x[self.naming[0]] + return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) + + +# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) class Uniform(Prior): xmin: float = 0.0 diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index a261a7fe..d8402cfe 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -2,10 +2,10 @@ from dataclasses import field from typing import Callable, Union -from chex import assert_rank import jax import jax.numpy as jnp from beartype import beartype as typechecker +from chex import assert_rank from jaxtyping import Array, Float, jaxtyped From b3ebc5efd21f958aa4d931a4954f48d2828adcab Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:42:35 -0400 Subject: [PATCH 008/172] Add propagate)name method in transform --- src/jimgw/transforms.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8402cfe..0ee32927 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -27,13 +27,13 @@ def __init__( def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: return self.transform(x) - + @abstractmethod def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Transform the input x to transformed coordinate y and return the log Jacobian determinant. This only works if the transform is a N -> N transform. - + Parameters ---------- x : dict[str, Float] @@ -63,7 +63,14 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: y : dict[str, Float] The transformed dictionary. 
""" - raise NotImplementedError + raise NotImplementedError + + def propagate_name(self, x: list[str]) -> list[str]: + input_set = set(x) + from_set = set(self.name_mapping[0]) + to_set = set(self.name_mapping[1]) + return list(input_set - from_set | to_set) + class LogitToUniform(Transform): """ @@ -87,7 +94,10 @@ def __init__( ): super().__init__(name_mapping) self.bounds = bounds - self.transform_func = lambda x: (self.bounds[1] - self.bounds[0]) / (1 + jnp.exp(-x)) + self.bounds[0] + self.transform_func = ( + lambda x: (self.bounds[1] - self.bounds[0]) / (1 + jnp.exp(-x)) + + self.bounds[0] + ) def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: input_params = x.pop(self.name_mapping[0][0]) @@ -102,4 +112,4 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: assert_rank(input_params, 0) output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params - return x \ No newline at end of file + return x From 95ef89b0fc466e1efdf15f9a4f0a8284a434343d Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:43:00 -0400 Subject: [PATCH 009/172] Rename composite to Combine for combining priors --- src/jimgw/prior.py | 118 +++++++++++++++++++++++++++------------------ 1 file changed, 70 insertions(+), 48 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e7b1acda..b257ef38 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,6 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec +from jimgw.transforms import Transform class Prior(Distribution): @@ -22,20 +23,20 @@ class Prior(Distribution): the names of the parameters and the transforms that are applied to them. 
""" - naming: list[str] + parameter_names: list[str] @property def n_dim(self): - return len(self.naming) + return len(self.parameter_names) - def __init__(self, naming: list[str]): + def __init__(self, parameter_names: list[str]): """ Parameters ---------- - naming : list[str] + parameter_names : list[str] A list of names for the parameters of the prior. """ - self.naming = naming + self.parameter_names = parameter_names def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ @@ -47,8 +48,8 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: An array of parameters. Shape (n_dim,). """ - return dict(zip(self.naming, x)) - + return dict(zip(self.parameter_names, x)) + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -62,10 +63,10 @@ def log_prob(self, x: dict[str, Array]) -> Float: class Logit(Prior): def __repr__(self): - return f"Logit(naming={self.naming})" + return f"Logit(parameter_names={self.parameter_names})" - def __init__(self, naming: list[str], **kwargs): - super().__init__(naming) + def __init__(self, parameter_names: list[str], **kwargs): + super().__init__(parameter_names) assert self.n_dim == 1, "Logit needs to be 1D distributions" def sample( @@ -92,10 +93,68 @@ def sample( return self.add_name(samples[None]) def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] + variable = x[self.parameter_names[0]] return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) +class Sequential(Prior): + """ + A prior class constructed + """ + + members: list[Prior | Transform] = field(default_factory=list) + + def __repr__(self): + return ( + f"Sequential(priors={self.members}, parameter_names={self.parameter_names})" + ) + + def __init( + self, + members: list[Prior | Transform], + ): + self.members = members + +class Combine(Prior): + """ + A prior class constructed by joinning multiple priors together to form a multivariate prior. 
+ This assumes the priors composing the Combine class are independent. + """ + + priors: list[Prior] = field(default_factory=list) + + def __repr__(self): + return ( + f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" + ) + + def __init__( + self, + priors: list[Prior], + **kwargs, + ): + parameter_names = [] + for prior in priors: + parameter_names += prior.parameter_names + self.priors = priors + self.parameter_names = parameter_names + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + output = {} + for prior in self.priors: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + return output + + def log_prob(self, x: dict[str, Float]) -> Float: + output = 0.0 + for prior in self.priors: + output += prior.log_prob(x) + return output + + # ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) class Uniform(Prior): @@ -686,40 +745,3 @@ def log_prob(self, x: dict[str, Array]) -> Float: - 0.5 * ((variable - self.mean) / self.std) ** 2 ) return output - - -class Composite(Prior): - priors: list[Prior] = field(default_factory=list) - - def __repr__(self): - return f"Composite(priors={self.priors}, naming={self.naming})" - - def __init__( - self, - priors: list[Prior], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - naming = [] - self.transforms = {} - for prior in priors: - naming += prior.naming - self.transforms.update(prior.transforms) - self.priors = priors - self.naming = naming - self.transforms.update(transforms) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - output = {} - for prior in self.priors: - rng_key, subkey = jax.random.split(rng_key) - output.update(prior.sample(subkey, n_samples)) - return output - - def log_prob(self, x: dict[str, Float]) -> Float: - output = 0.0 - for prior in 
self.priors: - output += prior.log_prob(x) - return output From e6e45efc00b1790a4e0d8ef087aab51bf627472d Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:46:19 -0400 Subject: [PATCH 010/172] scaffold prior test --- test/test_prior.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/test/test_prior.py b/test/test_prior.py index 53b065da..98b8b1c4 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -1 +1,16 @@ -from jimgw.prior import Composite, Unconstrained_Uniform, Uniform +from jimgw.prior import * + + +class TestUnivariatePrior: + def test_logit(self): + p = Logit() + +class TestPriorOperations: + def test_combine(self): + raise NotImplementedError + + def test_sequence(self): + raise NotImplementedError + + def test_factor(self): + raise NotImplementedError \ No newline at end of file From d503a37e6089f395ffdfa934ee6cd9f107e66aa8 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:49:02 -0400 Subject: [PATCH 011/172] Add Sequential Transform --- src/jimgw/prior.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index b257ef38..3dbd6491 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -97,12 +97,13 @@ def log_prob(self, x: dict[str, Float]) -> Float: return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) -class Sequential(Prior): +class SequentialTransform(Prior): """ - A prior class constructed + Transform a prior distribution by applying a sequence of transforms. 
""" - members: list[Prior | Transform] = field(default_factory=list) + base_prior: Prior + transforms: list[Transform] def __repr__(self): return ( @@ -111,9 +112,30 @@ def __repr__(self): def __init( self, - members: list[Prior | Transform], + base_prior: Prior, + transforms: list[Transform], ): - self.members = members + + self.base_prior = base_prior + self.transforms = transforms + self.parameter_names = base_prior.parameter_names + for transform in transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + output = self.base_prior.sample(rng_key, n_samples) + for transform in self.transforms: + output, _ = transform.forward(output) + return output + + def log_prob(self, x: dict[str, Float]) -> Float: + output = self.base_prior.log_prob(x) + for transform in self.transforms: + _, log_jacobian = transform.transform(x) + output += log_jacobian + return output class Combine(Prior): """ From 61b3b9ec885fac94f4545cb748c64f5e07deb311 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 16:03:26 -0400 Subject: [PATCH 012/172] Separate logit and scaling. 
Also rename prior --- src/jimgw/prior.py | 2 +- src/jimgw/transforms.py | 75 ++++++++++++++++++++++++++++++----------- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3dbd6491..30442662 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -60,7 +60,7 @@ def log_prob(self, x: dict[str, Array]) -> Float: @jaxtyped(typechecker=typechecker) -class Logit(Prior): +class LogisticDistribution(Prior): def __repr__(self): return f"Logit(parameter_names={self.parameter_names})" diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 0ee32927..a8d0a908 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -72,32 +72,13 @@ def propagate_name(self, x: list[str]) -> list[str]: return list(input_set - from_set | to_set) -class LogitToUniform(Transform): - """ - Transform from unconstrained space to uniform space. - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - bounds : tuple[Float, Float] - The lower and upper bounds of the uniform distribution. 
- - """ - - bounds: tuple[Float, Float] +class UnivariateTransform(Transform): def __init__( self, name_mapping: tuple[list[str], list[str]], - bounds: tuple[Float, Float], ): super().__init__(name_mapping) - self.bounds = bounds - self.transform_func = ( - lambda x: (self.bounds[1] - self.bounds[0]) / (1 + jnp.exp(-x)) - + self.bounds[0] - ) def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: input_params = x.pop(self.name_mapping[0][0]) @@ -113,3 +94,57 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x + + +class ScaleToRange(UnivariateTransform): + + range: tuple[Float, Float] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + range: tuple[Float, Float], + ): + super().__init__(name_mapping) + self.range = range + self.transform_func = ( + lambda x: (self.range[1] - self.range[0]) * x + self.range[0] + ) + + +class Logit(Transform): + """ + Logit transform following + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + + +class Sine(Transform): + """ + Transform from unconstrained space to uniform space. + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: jnp.sin(x) From b47e6aed93c2abe1777aa5a6c8a150164d14e64d Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 16:10:08 -0400 Subject: [PATCH 013/172] Univeriate Transform seems working --- src/jimgw/prior.py | 8 ++++---- src/jimgw/transforms.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 30442662..a9274259 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -63,7 +63,7 @@ def log_prob(self, x: dict[str, Array]) -> Float: class LogisticDistribution(Prior): def __repr__(self): - return f"Logit(parameter_names={self.parameter_names})" + return f"Logistic(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) @@ -107,10 +107,10 @@ class SequentialTransform(Prior): def __repr__(self): return ( - f"Sequential(priors={self.members}, parameter_names={self.parameter_names})" + f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" ) - def __init( + def __init__( self, base_prior: Prior, transforms: list[Transform], @@ -127,7 +127,7 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) for transform in self.transforms: - output, _ = transform.forward(output) + output = jax.vmap(transform.forward)(output) return output def log_prob(self, x: dict[str, Float]) -> Float: diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index a8d0a908..52ec43be 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -112,7 +112,7 @@ def __init__( ) -class Logit(Transform): +class Logit(UnivariateTransform): """ Logit transform following @@ -131,7 +131,7 @@ def __init__( self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) -class Sine(Transform): +class 
Sine(UnivariateTransform): """ Transform from unconstrained space to uniform space. From e1cc408c2400807d35492a875e5ba1d10823e382 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 16:37:18 -0400 Subject: [PATCH 014/172] Uniform prior now working. Also add prior test --- src/jimgw/prior.py | 133 ++++++++++++++++++---------------------- src/jimgw/transforms.py | 22 ++++--- test/test_prior.py | 8 ++- 3 files changed, 80 insertions(+), 83 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index a9274259..a7a81acc 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform +from jimgw.transforms import Transform, Logit, Scale, Offset class Prior(Distribution): @@ -134,104 +134,87 @@ def log_prob(self, x: dict[str, Float]) -> Float: output = self.base_prior.log_prob(x) for transform in self.transforms: _, log_jacobian = transform.transform(x) - output += log_jacobian + output -= log_jacobian return output -class Combine(Prior): - """ - A prior class constructed by joinning multiple priors together to form a multivariate prior. - This assumes the priors composing the Combine class are independent. - """ +# class Combine(Prior): +# """ +# A prior class constructed by joinning multiple priors together to form a multivariate prior. +# This assumes the priors composing the Combine class are independent. 
+# """ + +# priors: list[Prior] = field(default_factory=list) + +# def __repr__(self): +# return ( +# f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" +# ) + +# def __init__( +# self, +# priors: list[Prior], +# **kwargs, +# ): +# parameter_names = [] +# for prior in priors: +# parameter_names += prior.parameter_names +# self.priors = priors +# self.parameter_names = parameter_names + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# output = {} +# for prior in self.priors: +# rng_key, subkey = jax.random.split(rng_key) +# output.update(prior.sample(subkey, n_samples)) +# return output + +# def log_prob(self, x: dict[str, Float]) -> Float: +# output = 0.0 +# for prior in self.priors: +# output -= prior.log_prob(x) +# return output - priors: list[Prior] = field(default_factory=list) - def __repr__(self): - return ( - f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" - ) - def __init__( - self, - priors: list[Prior], - **kwargs, - ): - parameter_names = [] - for prior in priors: - parameter_names += prior.parameter_names - self.priors = priors - self.parameter_names = parameter_names - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - output = {} - for prior in self.priors: - rng_key, subkey = jax.random.split(rng_key) - output.update(prior.sample(subkey, n_samples)) - return output - - def log_prob(self, x: dict[str, Float]) -> Float: - output = 0.0 - for prior in self.priors: - output += prior.log_prob(x) - return output - - -# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) class Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 + _dist: Prior + + xmin: float + xmax: float def __repr__(self): return f"Uniform(xmin={self.xmin}, xmax={self.xmax})" def __init__( self, - xmin: Float, - xmax: Float, - naming: list[str], - transforms: 
dict[str, tuple[str, Callable]] = {}, - **kwargs, + xmin: float, + xmax: float, + parameter_names: list[str], ): - super().__init__(naming, transforms) + super().__init__(parameter_names) assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin + self._dist = SequentialTransform( + LogisticDistribution(parameter_names), + [ + Logit((parameter_names, parameter_names)), + Scale((parameter_names, parameter_names), xmax - xmin), + Offset((parameter_names, parameter_names), xmin), + ]) def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a uniform distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - samples = jax.random.uniform( - rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax - ) - return self.add_name(samples[None]) + return self._dist.sample(rng_key, n_samples) def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - return output + jnp.log(1.0 / (self.xmax - self.xmin)) + return self._dist.log_prob(x) +# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) class Unconstrained_Uniform(Prior): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 52ec43be..203f6bf1 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -95,21 +95,29 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: x[self.name_mapping[1][0]] = output_params return x +class Scale(UnivariateTransform): + scale: Float -class ScaleToRange(UnivariateTransform): + def __init__( + self, + name_mapping: 
tuple[list[str], list[str]], + scale: Float, + ): + super().__init__(name_mapping) + self.scale = scale + self.transform_func = lambda x: x * self.scale - range: tuple[Float, Float] +class Offset(UnivariateTransform): + offset: Float def __init__( self, name_mapping: tuple[list[str], list[str]], - range: tuple[Float, Float], + offset: Float, ): super().__init__(name_mapping) - self.range = range - self.transform_func = ( - lambda x: (self.range[1] - self.range[0]) * x + self.range[0] - ) + self.offset = offset + self.transform_func = lambda x: x + self.offset class Logit(UnivariateTransform): diff --git a/test/test_prior.py b/test/test_prior.py index 98b8b1c4..6890f9b9 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -2,9 +2,15 @@ class TestUnivariatePrior: - def test_logit(self): + def test_logistic(self): p = Logit() + def test_uniform(self): + p = Uniform(0.0, 10.0, ['x']) + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.allclose(log_prob, -jnp.log(10.0)) + class TestPriorOperations: def test_combine(self): raise NotImplementedError From d68cb366e3d8b3ce22aa0dc8f2bfb1f412eb747a Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 17:10:29 -0400 Subject: [PATCH 015/172] Added inverse transform and uniform now perform correct --- src/jimgw/prior.py | 83 +++-------------------------------------- src/jimgw/transforms.py | 56 ++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 80 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index a7a81acc..35722707 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -131,10 +131,11 @@ def sample( return output def log_prob(self, x: dict[str, Float]) -> Float: - output = self.base_prior.log_prob(x) - for transform in self.transforms: - _, log_jacobian = transform.transform(x) - output -= log_jacobian + output = 0.0 + for transform in reversed(self.transforms): + x, log_jacobian = transform.inverse_transform(x) + 
output += log_jacobian + output += self.base_prior.log_prob(x) return output # class Combine(Prior): @@ -216,80 +217,6 @@ def log_prob(self, x: dict[str, Array]) -> Float: # ====================== Things below may need rework ====================== -@jaxtyped(typechecker=typechecker) -class Unconstrained_Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 - - def __repr__(self): - return f"Unconstrained_Uniform(xmin={self.xmin}, xmax={self.xmax})" - - def __init__( - self, - xmin: Float, - xmax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - local_transform = self.transforms - - def new_transform(param): - param[self.naming[0]] = self.to_range(param[self.naming[0]]) - return local_transform[self.naming[0]][1](param) - - self.transforms = { - self.naming[0]: (local_transform[self.naming[0]][0], new_transform) - } - - def to_range(self, x: Float) -> Float: - """ - Transform the parameters to the range of the prior. - - Parameters - ---------- - x : Float - The parameters to transform. - - Returns - ------- - x : dict - A dictionary of parameters with the transforms applied. - """ - return (self.xmax - self.xmin) / (1 + jnp.exp(-x)) + self.xmin - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a uniform distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : - An array of shape (n_samples, n_dim) containing the samples. 
- - """ - samples = jax.random.uniform(rng_key, (n_samples,), minval=0, maxval=1) - samples = jnp.log(samples / (1 - samples)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) - - class Sphere(Prior): """ A prior on a sphere represented by Cartesian coordinates. diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 203f6bf1..c5e9120a 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -18,6 +18,7 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] + inverse_func: Callable[[dict[str, Float]], dict[str, Float]] def __init__( self, @@ -47,6 +48,23 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. """ raise NotImplementedError + + @abstractmethod + def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Inverse transform the input x to transformed coordinate y. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + """ + raise NotImplementedError @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: @@ -64,6 +82,23 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: The transformed dictionary. """ raise NotImplementedError + + @abstractmethod + def backward(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Pull back the input x to transformed coordinate y. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. 
+ """ + raise NotImplementedError def propagate_name(self, x: list[str]) -> list[str]: input_set = set(x) @@ -88,12 +123,28 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: x[self.name_mapping[1][0]] = output_params return x, jnp.log(jacobian) + def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]: + output_params = x.pop(self.name_mapping[1][0]) + assert_rank(output_params, 0) + input_params = self.inverse_func(output_params) + jacobian = jax.jacfwd(self.inverse_func)(output_params) + x[self.name_mapping[0][0]] = input_params + return x, jnp.log(jacobian) + def forward(self, x: dict[str, Float]) -> dict[str, Float]: input_params = x.pop(self.name_mapping[0][0]) assert_rank(input_params, 0) output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x + + def backward(self, x: dict[str, Float]) -> dict[str, Float]: + output_params = x.pop(self.name_mapping[1][0]) + assert_rank(output_params, 0) + input_params = self.inverse_func(output_params) + x[self.name_mapping[0][0]] = input_params + return x + class Scale(UnivariateTransform): scale: Float @@ -106,6 +157,7 @@ def __init__( super().__init__(name_mapping) self.scale = scale self.transform_func = lambda x: x * self.scale + self.inverse_func = lambda x: x / self.scale class Offset(UnivariateTransform): offset: Float @@ -118,7 +170,7 @@ def __init__( super().__init__(name_mapping) self.offset = offset self.transform_func = lambda x: x + self.offset - + self.inverse_func = lambda x: x - self.offset class Logit(UnivariateTransform): """ @@ -137,7 +189,7 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) - + self.inverse_func = lambda x: jnp.log(x / (1 - x)) class Sine(UnivariateTransform): """ From c16eef5bec5b04aa7b2bddc8788dabfd6f801c06 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 17:13:11 -0400 Subject: [PATCH 016/172] Add comments to current sequential 
transform prior class --- src/jimgw/prior.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 35722707..9fee29d9 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -131,6 +131,10 @@ def sample( return output def log_prob(self, x: dict[str, Float]) -> Float: + """ + Requiring inverse transform in log_prob may not be the best option, + may need alternative + """ output = 0.0 for transform in reversed(self.transforms): x, log_jacobian = transform.inverse_transform(x) From 8ab92a481aac134ecef29883ea04eae725242fb3 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 13:13:34 -0400 Subject: [PATCH 017/172] Removing inverse. Inverse limts the type of transform one can use, And it doesn't seem to have case that will require log_prob on target space --- src/jimgw/prior.py | 14 +++++------ src/jimgw/transforms.py | 56 ----------------------------------------- 2 files changed, 7 insertions(+), 63 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 9fee29d9..61202609 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -132,14 +132,14 @@ def sample( def log_prob(self, x: dict[str, Float]) -> Float: """ - Requiring inverse transform in log_prob may not be the best option, - may need alternative + log_prob has to be evaluated in the space of the base_prior. 
+ + """ - output = 0.0 - for transform in reversed(self.transforms): - x, log_jacobian = transform.inverse_transform(x) - output += log_jacobian - output += self.base_prior.log_prob(x) + output = self.base_prior.log_prob(x) + for transform in self.transforms: + x, log_jacobian = transform.transform(x) + output -= log_jacobian return output # class Combine(Prior): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index c5e9120a..7ef8560a 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -8,7 +8,6 @@ from chex import assert_rank from jaxtyping import Array, Float, jaxtyped - class Transform(ABC): """ Base class for transform. @@ -18,8 +17,6 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] - inverse_func: Callable[[dict[str, Float]], dict[str, Float]] - def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -49,23 +46,6 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ raise NotImplementedError - @abstractmethod - def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Inverse transform the input x to transformed coordinate y. - - Parameters - ---------- - x : dict[str, Float] - The input dictionary. - - Returns - ------- - y : dict[str, Float] - The transformed dictionary. - """ - raise NotImplementedError - @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -82,23 +62,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: The transformed dictionary. """ raise NotImplementedError - - @abstractmethod - def backward(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Pull back the input x to transformed coordinate y. - - Parameters - ---------- - x : dict[str, Float] - The input dictionary. - - Returns - ------- - y : dict[str, Float] - The transformed dictionary. 
- """ - raise NotImplementedError def propagate_name(self, x: list[str]) -> list[str]: input_set = set(x) @@ -123,14 +86,6 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: x[self.name_mapping[1][0]] = output_params return x, jnp.log(jacobian) - def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]: - output_params = x.pop(self.name_mapping[1][0]) - assert_rank(output_params, 0) - input_params = self.inverse_func(output_params) - jacobian = jax.jacfwd(self.inverse_func)(output_params) - x[self.name_mapping[0][0]] = input_params - return x, jnp.log(jacobian) - def forward(self, x: dict[str, Float]) -> dict[str, Float]: input_params = x.pop(self.name_mapping[0][0]) assert_rank(input_params, 0) @@ -138,14 +93,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: x[self.name_mapping[1][0]] = output_params return x - def backward(self, x: dict[str, Float]) -> dict[str, Float]: - output_params = x.pop(self.name_mapping[1][0]) - assert_rank(output_params, 0) - input_params = self.inverse_func(output_params) - x[self.name_mapping[0][0]] = input_params - return x - - class Scale(UnivariateTransform): scale: Float @@ -157,7 +104,6 @@ def __init__( super().__init__(name_mapping) self.scale = scale self.transform_func = lambda x: x * self.scale - self.inverse_func = lambda x: x / self.scale class Offset(UnivariateTransform): offset: Float @@ -170,7 +116,6 @@ def __init__( super().__init__(name_mapping) self.offset = offset self.transform_func = lambda x: x + self.offset - self.inverse_func = lambda x: x - self.offset class Logit(UnivariateTransform): """ @@ -189,7 +134,6 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) - self.inverse_func = lambda x: jnp.log(x / (1 - x)) class Sine(UnivariateTransform): """ From d8b2d1feb0bdae6b4fd59bef8cc345eca2bb8c33 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 13:48:17 -0400 Subject: [PATCH 018/172] Add 
transformation function --- src/jimgw/prior.py | 19 +++++++++++++++---- test/test_prior.py | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 61202609..232b4a0f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -126,9 +126,7 @@ def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) - for transform in self.transforms: - output = jax.vmap(transform.forward)(output) - return output + return jax.vmap(self.transform)(output) def log_prob(self, x: dict[str, Float]) -> Float: """ @@ -141,6 +139,11 @@ def log_prob(self, x: dict[str, Float]) -> Float: x, log_jacobian = transform.transform(x) output -= log_jacobian return output + + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + for transform in self.transforms: + x = transform.forward(x) + return x # class Combine(Prior): # """ @@ -185,7 +188,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped(typechecker=typechecker) class Uniform(Prior): - _dist: Prior + _dist: SequentialTransform xmin: float xmax: float @@ -218,6 +221,14 @@ def sample( def log_prob(self, x: dict[str, Array]) -> Float: return self._dist.log_prob(x) + + def sample_base( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self._dist.base_prior.sample(rng_key, n_samples) + + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + return self._dist.transform(x) # ====================== Things below may need rework ====================== diff --git a/test/test_prior.py b/test/test_prior.py index 6890f9b9..b6eb7c87 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -7,7 +7,7 @@ def test_logistic(self): def test_uniform(self): p = Uniform(0.0, 10.0, ['x']) - samples = p.sample(jax.random.PRNGKey(0), 10000) + samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000) log_prob = 
jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, -jnp.log(10.0)) From ca1d6b6c1fa2ac7cae0d9a313693ea97af070fe4 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 14:07:16 -0400 Subject: [PATCH 019/172] Combine should be working now --- src/jimgw/prior.py | 75 +++++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 232b4a0f..85414064 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -145,44 +145,43 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]: x = transform.forward(x) return x -# class Combine(Prior): -# """ -# A prior class constructed by joinning multiple priors together to form a multivariate prior. -# This assumes the priors composing the Combine class are independent. -# """ - -# priors: list[Prior] = field(default_factory=list) - -# def __repr__(self): -# return ( -# f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" -# ) - -# def __init__( -# self, -# priors: list[Prior], -# **kwargs, -# ): -# parameter_names = [] -# for prior in priors: -# parameter_names += prior.parameter_names -# self.priors = priors -# self.parameter_names = parameter_names - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# output = {} -# for prior in self.priors: -# rng_key, subkey = jax.random.split(rng_key) -# output.update(prior.sample(subkey, n_samples)) -# return output - -# def log_prob(self, x: dict[str, Float]) -> Float: -# output = 0.0 -# for prior in self.priors: -# output -= prior.log_prob(x) -# return output +class Combine(Prior): + """ + A prior class constructed by joinning multiple priors together to form a multivariate prior. + This assumes the priors composing the Combine class are independent. 
+ """ + + priors: list[Prior] = field(default_factory=list) + + def __repr__(self): + return ( + f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" + ) + + def __init__( + self, + priors: list[Prior], + ): + parameter_names = [] + for prior in priors: + parameter_names += prior.parameter_names + self.priors = priors + self.parameter_names = parameter_names + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + output = {} + for prior in self.priors: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + return output + + def log_prob(self, x: dict[str, Float]) -> Float: + output = 0.0 + for prior in self.priors: + output += prior.log_prob(x) + return output From 2f3f412446ec346a8dbfe367dc7acfee7bf4c7d1 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 14:52:02 -0400 Subject: [PATCH 020/172] Sine is an illegal transform since its Jacobian could be negative --- src/jimgw/prior.py | 2 +- src/jimgw/transforms.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 85414064..1e7cb960 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -193,7 +193,7 @@ class Uniform(Prior): xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax})" + return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" def __init__( self, diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 7ef8560a..55a5d0ce 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -135,10 +135,10 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) -class Sine(UnivariateTransform): +class ArcSine(UnivariateTransform): """ - Transform from unconstrained space to uniform space. 
- + ArcSine transformation + Parameters ---------- name_mapping : tuple[list[str], list[str]] @@ -151,4 +151,4 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: jnp.sin(x) + self.transform_func = lambda x: jnp.arcsin(x) From 4978cb1bd8fbf7e1083d0552ebe198ab44de1040 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 25 Jul 2024 18:24:33 -0400 Subject: [PATCH 021/172] Modify Uniform and add UniformSphere --- src/jimgw/prior.py | 143 +++++++++++++++------------------------- src/jimgw/transforms.py | 37 +++++++++-- 2 files changed, 82 insertions(+), 98 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 1e7cb960..effbf388 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcCosine class Prior(Distribution): @@ -49,7 +49,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ return dict(zip(self.parameter_names, x)) - + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -106,9 +106,7 @@ class SequentialTransform(Prior): transforms: list[Transform] def __repr__(self): - return ( - f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" - ) + return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" def __init__( self, @@ -127,24 +125,28 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - + def log_prob(self, x: dict[str, Float]) -> Float: """ log_prob has to be evaluated in the space of the base_prior. 
- - """ output = self.base_prior.log_prob(x) for transform in self.transforms: x, log_jacobian = transform.transform(x) output -= log_jacobian return output - + + def sample_base( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self.base_prior.sample(rng_key, n_samples) + def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) return x + class Combine(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. @@ -184,16 +186,13 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output - @jaxtyped(typechecker=typechecker) -class Uniform(Prior): - _dist: SequentialTransform - +class Uniform(SequentialTransform): xmin: float xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" def __init__( self, @@ -201,94 +200,56 @@ def __init__( xmax: float, parameter_names: list[str], ): - super().__init__(parameter_names) + self.parameter_names = parameter_names assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - self._dist = SequentialTransform( - LogisticDistribution(parameter_names), + super().__init__( + LogisticDistribution(self.parameter_names), [ - Logit((parameter_names, parameter_names)), - Scale((parameter_names, parameter_names), xmax - xmin), - Offset((parameter_names, parameter_names), xmin), - ]) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self._dist.sample(rng_key, n_samples) - - def log_prob(self, x: dict[str, Array]) -> Float: - return self._dist.log_prob(x) - - def sample_base( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self._dist.base_prior.sample(rng_key, 
n_samples) - - def transform(self, x: dict[str, Float]) -> dict[str, Float]: - return self._dist.transform(x) - -# ====================== Things below may need rework ====================== + Logit((self.parameter_names, self.parameter_names)), + Scale((self.parameter_names, self.parameter_names), xmax - xmin), + Offset((self.parameter_names, self.parameter_names), xmin), + ], + ) -class Sphere(Prior): - """ - A prior on a sphere represented by Cartesian coordinates. - Magnitude is sampled from a uniform distribution. - """ +@jaxtyped(typechecker=typechecker) +class UniformSphere(Combine): def __repr__(self): - return f"Sphere(naming={self.naming})" + return f"UniformSphere(parameter_names={self.parameter_names})" - def __init__(self, naming: list[str], **kwargs): - name = naming[0] - self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] - self.transforms = { - self.naming[0]: ( - f"{naming}_x", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.cos(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[1]: ( - f"{naming}_y", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.sin(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[2]: ( - f"{naming}_z", - lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 3) - theta = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) + def __init__(self, parameter_names: list[str], **kwargs): + assert ( + len(parameter_names) == 1 + ), "UniformSphere only takes the name of the vector" + parameter_names = parameter_names[0] + self.parameter_names = [ + f"{parameter_names}_mag", + f"{parameter_names}_theta", + f"{parameter_names}_phi", + ] + super().__init__( + [ + Uniform(0.0, 1.0, [self.parameter_names[0]]), + SequentialTransform( + Uniform(-1.0, 1.0, 
[f"cos_{self.parameter_names[1]}"]), + [ + ArcCosine( + ( + [f"cos_{self.parameter_names[1]}"], + [self.parameter_names[1]], + ) + ) + ], + ), + Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + ] ) - phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) - mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) - return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - def log_prob(self, x: dict[str, Float]) -> Float: - theta = x[self.naming[0]] - phi = x[self.naming[1]] - mag = x[self.naming[2]] - output = jnp.where( - (mag > 1) - | (mag < 0) - | (phi > 2 * jnp.pi) - | (phi < 0) - | (theta > jnp.pi) - | (theta < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), - ) - return output + +# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 55a5d0ce..d8ee9536 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -from dataclasses import field -from typing import Callable, Union +from typing import Callable import jax import jax.numpy as jnp -from beartype import beartype as typechecker from chex import assert_rank -from jaxtyping import Array, Float, jaxtyped +from jaxtyping import Float + class Transform(ABC): """ @@ -17,6 +16,7 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] + def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -45,7 +45,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. 
""" raise NotImplementedError - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -92,7 +92,8 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x - + + class Scale(UnivariateTransform): scale: Float @@ -105,6 +106,7 @@ def __init__( self.scale = scale self.transform_func = lambda x: x * self.scale + class Offset(UnivariateTransform): offset: Float @@ -117,6 +119,7 @@ def __init__( self.offset = offset self.transform_func = lambda x: x + self.offset + class Logit(UnivariateTransform): """ Logit transform following @@ -135,10 +138,11 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + class ArcSine(UnivariateTransform): """ ArcSine transformation - + Parameters ---------- name_mapping : tuple[list[str], list[str]] @@ -152,3 +156,22 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) + + +class ArcCosine(UnivariateTransform): + """ + ArcCosine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: jnp.arccos(x) From c1115bd5b669bc8a3bcc7e30fe5e7405beb4ba97 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 06:27:36 -0400 Subject: [PATCH 022/172] Add Sine and Cosine Prior --- src/jimgw/prior.py | 50 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index effbf388..894ea9f7 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcCosine +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine class Prior(Distribution): @@ -214,6 +214,42 @@ def __init__( ) +@jaxtyped(typechecker=typechecker) +class Sine(SequentialTransform): + """ + A prior distribution where the pdf is proportional to sin(x) in the range [0, pi]. + """ + + def __repr__(self): + return f"Sine(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 1, "Sine needs to be 1D distributions" + super().__init__( + Uniform(-1.0, 1.0, f"cos_{self.parameter_names}"), + [ArcCosine(([f"cos_{self.parameter_names}"], [self.parameter_names]))], + ) + + +@jaxtyped(typechecker=typechecker) +class Cosine(SequentialTransform): + """ + A prior distribution where the pdf is proportional to cos(x) in the range [-pi/2, pi/2]. 
+ """ + + def __repr__(self): + return f"Cosine(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 1, "Cosine needs to be 1D distributions" + super().__init__( + Uniform(-1.0, 1.0, f"sin_{self.parameter_names}"), + [ArcSine(([f"sin_{self.parameter_names}"], [self.parameter_names]))], + ) + + @jaxtyped(typechecker=typechecker) class UniformSphere(Combine): @@ -233,17 +269,7 @@ def __init__(self, parameter_names: list[str], **kwargs): super().__init__( [ Uniform(0.0, 1.0, [self.parameter_names[0]]), - SequentialTransform( - Uniform(-1.0, 1.0, [f"cos_{self.parameter_names[1]}"]), - [ - ArcCosine( - ( - [f"cos_{self.parameter_names[1]}"], - [self.parameter_names[1]], - ) - ) - ], - ), + Sine([self.parameter_names[1]]), Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), ] ) From 01a6c1e65ebc4bdc53df48bad87afe1f08d47c58 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 06:57:11 -0400 Subject: [PATCH 023/172] Revert "Modify Uniform and add UniformSphere" This reverts commit 4978cb1bd8fbf7e1083d0552ebe198ab44de1040. 
--- src/jimgw/prior.py | 143 +++++++++++++++++++++++++--------------- src/jimgw/transforms.py | 37 ++--------- 2 files changed, 98 insertions(+), 82 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index effbf388..1e7cb960 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcCosine +from jimgw.transforms import Transform, Logit, Scale, Offset class Prior(Distribution): @@ -49,7 +49,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ return dict(zip(self.parameter_names, x)) - + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -106,7 +106,9 @@ class SequentialTransform(Prior): transforms: list[Transform] def __repr__(self): - return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" + return ( + f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" + ) def __init__( self, @@ -125,28 +127,24 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - + def log_prob(self, x: dict[str, Float]) -> Float: """ log_prob has to be evaluated in the space of the base_prior. 
+ + """ output = self.base_prior.log_prob(x) for transform in self.transforms: x, log_jacobian = transform.transform(x) output -= log_jacobian return output - - def sample_base( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self.base_prior.sample(rng_key, n_samples) - + def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) return x - class Combine(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. @@ -186,13 +184,16 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output + @jaxtyped(typechecker=typechecker) -class Uniform(SequentialTransform): +class Uniform(Prior): + _dist: SequentialTransform + xmin: float xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" + return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" def __init__( self, @@ -200,56 +201,94 @@ def __init__( xmax: float, parameter_names: list[str], ): - self.parameter_names = parameter_names + super().__init__(parameter_names) assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - super().__init__( - LogisticDistribution(self.parameter_names), + self._dist = SequentialTransform( + LogisticDistribution(parameter_names), [ - Logit((self.parameter_names, self.parameter_names)), - Scale((self.parameter_names, self.parameter_names), xmax - xmin), - Offset((self.parameter_names, self.parameter_names), xmin), - ], - ) + Logit((parameter_names, parameter_names)), + Scale((parameter_names, parameter_names), xmax - xmin), + Offset((parameter_names, parameter_names), xmin), + ]) + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self._dist.sample(rng_key, n_samples) -@jaxtyped(typechecker=typechecker) -class 
UniformSphere(Combine): + def log_prob(self, x: dict[str, Array]) -> Float: + return self._dist.log_prob(x) + + def sample_base( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self._dist.base_prior.sample(rng_key, n_samples) + + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + return self._dist.transform(x) + +# ====================== Things below may need rework ====================== + +class Sphere(Prior): + """ + A prior on a sphere represented by Cartesian coordinates. + + Magnitude is sampled from a uniform distribution. + """ def __repr__(self): - return f"UniformSphere(parameter_names={self.parameter_names})" + return f"Sphere(naming={self.naming})" - def __init__(self, parameter_names: list[str], **kwargs): - assert ( - len(parameter_names) == 1 - ), "UniformSphere only takes the name of the vector" - parameter_names = parameter_names[0] - self.parameter_names = [ - f"{parameter_names}_mag", - f"{parameter_names}_theta", - f"{parameter_names}_phi", - ] - super().__init__( - [ - Uniform(0.0, 1.0, [self.parameter_names[0]]), - SequentialTransform( - Uniform(-1.0, 1.0, [f"cos_{self.parameter_names[1]}"]), - [ - ArcCosine( - ( - [f"cos_{self.parameter_names[1]}"], - [self.parameter_names[1]], - ) - ) - ], - ), - Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), - ] - ) + def __init__(self, naming: list[str], **kwargs): + name = naming[0] + self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] + self.transforms = { + self.naming[0]: ( + f"{naming}_x", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.cos(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[1]: ( + f"{naming}_y", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.sin(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[2]: ( + f"{naming}_z", + lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], + ), + } + def sample( + self, rng_key: PRNGKeyArray, 
n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + rng_keys = jax.random.split(rng_key, 3) + theta = jnp.arccos( + jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) + ) + phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) + mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) + return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) -# ====================== Things below may need rework ====================== + def log_prob(self, x: dict[str, Float]) -> Float: + theta = x[self.naming[0]] + phi = x[self.naming[1]] + mag = x[self.naming[2]] + output = jnp.where( + (mag > 1) + | (mag < 0) + | (phi > 2 * jnp.pi) + | (phi < 0) + | (theta > jnp.pi) + | (theta < 0), + jnp.zeros_like(0) - jnp.inf, + jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), + ) + return output @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8ee9536..55a5d0ce 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod -from typing import Callable +from dataclasses import field +from typing import Callable, Union import jax import jax.numpy as jnp +from beartype import beartype as typechecker from chex import assert_rank -from jaxtyping import Float - +from jaxtyping import Array, Float, jaxtyped class Transform(ABC): """ @@ -16,7 +17,6 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] - def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -45,7 +45,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. 
""" raise NotImplementedError - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -92,8 +92,7 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x - - + class Scale(UnivariateTransform): scale: Float @@ -106,7 +105,6 @@ def __init__( self.scale = scale self.transform_func = lambda x: x * self.scale - class Offset(UnivariateTransform): offset: Float @@ -119,7 +117,6 @@ def __init__( self.offset = offset self.transform_func = lambda x: x + self.offset - class Logit(UnivariateTransform): """ Logit transform following @@ -138,11 +135,10 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) - class ArcSine(UnivariateTransform): """ ArcSine transformation - + Parameters ---------- name_mapping : tuple[list[str], list[str]] @@ -156,22 +152,3 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) - - -class ArcCosine(UnivariateTransform): - """ - ArcCosine transformation - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. 
- - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - self.transform_func = lambda x: jnp.arccos(x) From 7ee41335d7820f21f1b777d0482c39a989a5f65a Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:12:10 -0400 Subject: [PATCH 024/172] Add standard normal distribution --- src/jimgw/prior.py | 96 +++++++++++++++++++--------------------------- 1 file changed, 39 insertions(+), 57 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 894ea9f7..3c411c42 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -63,17 +63,17 @@ def log_prob(self, x: dict[str, Array]) -> Float: class LogisticDistribution(Prior): def __repr__(self): - return f"Logistic(parameter_names={self.parameter_names})" + return f"LogisticDistribution(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) - assert self.n_dim == 1, "Logit needs to be 1D distributions" + assert self.n_dim == 1, "LogisticDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: """ - Sample from a logit distribution. + Sample from a logistic distribution. Parameters ---------- @@ -97,6 +97,42 @@ def log_prob(self, x: dict[str, Float]) -> Float: return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) +@jaxtyped(typechecker=typechecker) +class StandardNormalDistribution(Prior): + + def __repr__(self): + return f"StandardNormalDistribution(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str], **kwargs): + super().__init__(parameter_names) + assert self.n_dim == 1, "StandardNormalDistribution needs to be 1D distributions" + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a standard normal distribution. 
+ + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + samples = jax.random.normal(rng_key, (n_samples,)) + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Float]) -> Float: + variable = x[self.parameter_names[0]] + return -0.5 * variable ** 2 - 0.5 * jnp.log(2 * jnp.pi) + class SequentialTransform(Prior): """ Transform a prior distribution by applying a sequence of transforms. @@ -624,57 +660,3 @@ def log_prob(self, x: dict[str, Float]) -> Float: ) log_p = self.alpha * variable + jnp.log(self.normalization) return log_p + log_in_range - - -@jaxtyped(typechecker=typechecker) -class Normal(Prior): - mean: Float = 0.0 - std: Float = 1.0 - - def __repr__(self): - return f"Normal(mean={self.mean}, std={self.std})" - - def __init__( - self, - mean: Float, - std: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Normal needs to be 1D distributions" - self.mean = mean - self.std = std - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a normal distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. 
- - """ - samples = jax.random.normal(rng_key, (n_samples,)) - samples = self.mean + samples * self.std - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = ( - -0.5 * jnp.log(2 * jnp.pi) - - jnp.log(self.std) - - 0.5 * ((variable - self.mean) / self.std) ** 2 - ) - return output From 6a3579296b9cd646d815f49897454af2fc57d7e2 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:18:22 -0400 Subject: [PATCH 025/172] Add periodic uniform prior --- src/jimgw/prior.py | 31 +++++++++++++++++++++++++++++-- src/jimgw/transforms.py | 19 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3c411c42..e2d98f71 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine, Modulo class Prior(Distribution): @@ -249,6 +249,33 @@ def __init__( ], ) +@jaxtyped(typechecker=typechecker) +class PeriodicUniform(SequentialTransform): + xmin: float + xmax: float + + def __repr__(self): + return f"PeriodicUniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" + + def __init__( + self, + xmin: float, + xmax: float, + parameter_names: list[str], + ): + self.parameter_names = parameter_names + assert self.n_dim == 1, "PeriodicUniform needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + super().__init__( + LogisticDistribution(self.parameter_names), + [ + Logit((self.parameter_names, self.parameter_names)), + Modulo((self.parameter_names, self.parameter_names), xmax - xmin), + Offset((self.parameter_names, self.parameter_names), xmin), + ], + ) + 
@jaxtyped(typechecker=typechecker) class Sine(SequentialTransform): @@ -306,7 +333,7 @@ def __init__(self, parameter_names: list[str], **kwargs): [ Uniform(0.0, 1.0, [self.parameter_names[0]]), Sine([self.parameter_names[1]]), - Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + PeriodicUniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), ] ) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8ee9536..05dacdb1 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -138,6 +138,25 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) +class Modulo(UnivariateTransform): + """ + Modulo transform following + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + modulo: Float, + ): + super().__init__(name_mapping) + self.modulo = modulo + self.transform_func = lambda x: jnp.mod(x, self.modulo) class ArcSine(UnivariateTransform): """ From 1bcf32c239c73ef904bda3b1aff86e761dc66bf9 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:20:53 -0400 Subject: [PATCH 026/172] Reformat --- src/jimgw/prior.py | 8 ++++++-- src/jimgw/transforms.py | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e2d98f71..7dd50018 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -105,7 +105,9 @@ def __repr__(self): def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) - assert self.n_dim == 1, "StandardNormalDistribution needs to be 1D distributions" + assert ( + self.n_dim == 1 + ), "StandardNormalDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int @@ -131,7 +133,8 @@ def sample( def log_prob(self, x: dict[str, Float]) -> Float: variable = x[self.parameter_names[0]] - return -0.5 * 
variable ** 2 - 0.5 * jnp.log(2 * jnp.pi) + return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) + class SequentialTransform(Prior): """ @@ -249,6 +252,7 @@ def __init__( ], ) + @jaxtyped(typechecker=typechecker) class PeriodicUniform(SequentialTransform): xmin: float diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 05dacdb1..8a4787c5 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -138,6 +138,7 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + class Modulo(UnivariateTransform): """ Modulo transform following @@ -158,6 +159,7 @@ def __init__( self.modulo = modulo self.transform_func = lambda x: jnp.mod(x, self.modulo) + class ArcSine(UnivariateTransform): """ ArcSine transformation From 5d98aeb529c165fccde0d14400e6ccb0408fdfd8 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:44:30 -0400 Subject: [PATCH 027/172] Revert "Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-class' into sphere_prior" This reverts commit 17450be4adc3c03d24d23c8ccbb6f9af418980e7, reversing changes made to 1bcf32c239c73ef904bda3b1aff86e761dc66bf9. 
--- src/jimgw/prior.py | 72 +++++++++++++---------------------------- src/jimgw/transforms.py | 36 +++++++++++++++++---- 2 files changed, 51 insertions(+), 57 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 74cc39dd..7dd50018 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -49,7 +49,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ return dict(zip(self.parameter_names, x)) - + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -145,9 +145,7 @@ class SequentialTransform(Prior): transforms: list[Transform] def __repr__(self): - return ( - f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" - ) + return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" def __init__( self, @@ -166,24 +164,28 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - + def log_prob(self, x: dict[str, Float]) -> Float: """ log_prob has to be evaluated in the space of the base_prior. - - """ output = self.base_prior.log_prob(x) for transform in self.transforms: x, log_jacobian = transform.transform(x) output -= log_jacobian return output - + + def sample_base( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self.base_prior.sample(rng_key, n_samples) + def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) return x + class Combine(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. 
@@ -223,16 +225,13 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output - @jaxtyped(typechecker=typechecker) -class Uniform(Prior): - _dist: SequentialTransform - +class Uniform(SequentialTransform): xmin: float xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" def __init__( self, @@ -240,22 +239,19 @@ def __init__( xmax: float, parameter_names: list[str], ): - super().__init__(parameter_names) + self.parameter_names = parameter_names assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - self._dist = SequentialTransform( - LogisticDistribution(parameter_names), + super().__init__( + LogisticDistribution(self.parameter_names), [ - Logit((parameter_names, parameter_names)), - Scale((parameter_names, parameter_names), xmax - xmin), - Offset((parameter_names, parameter_names), xmin), - ]) + Logit((self.parameter_names, self.parameter_names)), + Scale((self.parameter_names, self.parameter_names), xmax - xmin), + Offset((self.parameter_names, self.parameter_names), xmin), + ], + ) - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self._dist.sample(rng_key, n_samples) @jaxtyped(typechecker=typechecker) class PeriodicUniform(SequentialTransform): @@ -325,7 +321,7 @@ def __init__(self, parameter_names: list[str]): class UniformSphere(Combine): def __repr__(self): - return f"Sphere(naming={self.naming})" + return f"UniformSphere(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): assert ( @@ -345,32 +341,8 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 3) - theta = jnp.arccos( - 
jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) - mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) - return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - def log_prob(self, x: dict[str, Float]) -> Float: - theta = x[self.naming[0]] - phi = x[self.naming[1]] - mag = x[self.naming[2]] - output = jnp.where( - (mag > 1) - | (mag < 0) - | (phi > 2 * jnp.pi) - | (phi < 0) - | (theta > jnp.pi) - | (theta < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), - ) - return output +# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index da8c1591..8a4787c5 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -from dataclasses import field -from typing import Callable, Union +from typing import Callable import jax import jax.numpy as jnp -from beartype import beartype as typechecker from chex import assert_rank -from jaxtyping import Array, Float, jaxtyped +from jaxtyping import Float + class Transform(ABC): """ @@ -17,6 +16,7 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] + def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -45,7 +45,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. 
""" raise NotImplementedError - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -92,7 +92,8 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x - + + class Scale(UnivariateTransform): scale: Float @@ -105,6 +106,7 @@ def __init__( self.scale = scale self.transform_func = lambda x: x * self.scale + class Offset(UnivariateTransform): offset: Float @@ -117,6 +119,7 @@ def __init__( self.offset = offset self.transform_func = lambda x: x + self.offset + class Logit(UnivariateTransform): """ Logit transform following @@ -160,7 +163,7 @@ def __init__( class ArcSine(UnivariateTransform): """ ArcSine transformation - + Parameters ---------- name_mapping : tuple[list[str], list[str]] @@ -174,3 +177,22 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) + + +class ArcCosine(UnivariateTransform): + """ + ArcCosine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: jnp.arccos(x) From 87ee212ac44d5afa09c40312c4d5570b735ae2de Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:47:00 -0400 Subject: [PATCH 028/172] Minor text change --- src/jimgw/prior.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 7dd50018..94363f4c 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -195,9 +195,7 @@ class Combine(Prior): priors: list[Prior] = field(default_factory=list) def __repr__(self): - return ( - f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" - ) + return f"Combine(priors={self.priors}, parameter_names={self.parameter_names})" def __init__( self, From cc1448e2ec120828e5bc1b715cd2ae4c6e01563b Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 08:28:48 -0400 Subject: [PATCH 029/172] Remove PeriodicUniform --- src/jimgw/prior.py | 30 +----------------------------- src/jimgw/transforms.py | 21 --------------------- 2 files changed, 1 insertion(+), 50 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 94363f4c..c571c7dc 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine, Modulo +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine class Prior(Distribution): @@ -251,34 +251,6 @@ def __init__( ) -@jaxtyped(typechecker=typechecker) -class PeriodicUniform(SequentialTransform): - xmin: float - xmax: float - - def __repr__(self): - return f"PeriodicUniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" - - def __init__( - self, - xmin: float, - xmax: float, 
- parameter_names: list[str], - ): - self.parameter_names = parameter_names - assert self.n_dim == 1, "PeriodicUniform needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - super().__init__( - LogisticDistribution(self.parameter_names), - [ - Logit((self.parameter_names, self.parameter_names)), - Modulo((self.parameter_names, self.parameter_names), xmax - xmin), - Offset((self.parameter_names, self.parameter_names), xmin), - ], - ) - - @jaxtyped(typechecker=typechecker) class Sine(SequentialTransform): """ diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 8a4787c5..d8ee9536 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -139,27 +139,6 @@ def __init__( self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) -class Modulo(UnivariateTransform): - """ - Modulo transform following - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - modulo: Float, - ): - super().__init__(name_mapping) - self.modulo = modulo - self.transform_func = lambda x: jnp.mod(x, self.modulo) - - class ArcSine(UnivariateTransform): """ ArcSine transformation From 6070f130511caf40fde0a4b00ea87f430d80d653 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 08:41:51 -0400 Subject: [PATCH 030/172] Use self.sample_base --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index c571c7dc..53271680 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -162,7 +162,7 @@ def __init__( def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - output = self.base_prior.sample(rng_key, n_samples) + output = self.sample_base(rng_key, n_samples) return jax.vmap(self.transform)(output) def log_prob(self, x: dict[str, Float]) -> Float: From 
8a301f2759b7ddc3bb765d4ca8a00de3dd88994c Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 12:27:25 -0400 Subject: [PATCH 031/172] Update prior.py --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 53271680..de378973 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -307,7 +307,7 @@ def __init__(self, parameter_names: list[str], **kwargs): [ Uniform(0.0, 1.0, [self.parameter_names[0]]), Sine([self.parameter_names[1]]), - PeriodicUniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), ] ) From d761f6bfb83da27d83813f4cb1fce44a2882ebac Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 12:47:01 -0400 Subject: [PATCH 032/172] Update prior.py to include powerLaw --- src/jimgw/prior.py | 191 +++++++++++++++++++++++++++------------------ 1 file changed, 113 insertions(+), 78 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index de378973..637bbc2b 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine, PowerLawTransform, ParetoTransform class Prior(Distribution): @@ -311,6 +311,41 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) +@jaxtyped(typechecker=typechecker) +class PowerLaw(SequentialTransform): + xmin: float + xmax: float + alpha: float + + def __repr__(self): + return f"PowerLaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" + + def __init__( + self, + xmin: float, + xmax: float, + alpha: float, + parameter_names: list[str], + ): + self.parameter_names = parameter_names 
+ assert self.n_dim == 1, "Power law needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + if self.alpha == -1.0: + transform = ParetoTransform((self.parameter_names, self.parameter_names), xmin, xmax) + else: + transform = PowerLawTransform( + (self.parameter_names, self.parameter_names), xmin, xmax, alpha + ) + super().__init__( + LogisticDistribution(self.parameter_names), + [ + Logit((self.parameter_names, self.parameter_names)), + transform, + ], + ) + # ====================== Things below may need rework ====================== @@ -504,83 +539,83 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output + jnp.log(jnp.sin(zenith)) -@jaxtyped(typechecker=typechecker) -class PowerLaw(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ x^{\alpha} - """ - - xmin: float = 0.0 - xmax: float = 1.0 - alpha: float = 0.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: float, - xmax: float, - alpha: Union[Int, float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin > 0.0, "With negative alpha, xmin must > 0" - assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - if alpha == -1: - self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) - else: - self.normalization = (1 + self.alpha) / ( - self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a power-law distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. 
- - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - if self.alpha == -1: - samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) - else: - samples = ( - self.xmin ** (1.0 + self.alpha) - + q_samples - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) - return log_p + log_in_range +# @jaxtyped(typechecker=typechecker) +# class PowerLaw(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ x^{\alpha} +# """ + +# xmin: float = 0.0 +# xmax: float = 1.0 +# alpha: float = 0.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: float, +# xmax: float, +# alpha: Union[Int, float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin > 0.0, "With negative alpha, xmin must > 0" +# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha +# if alpha == -1: +# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) +# else: +# self.normalization = (1 + self.alpha) / ( +# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# 
""" +# Sample from a power-law distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# if self.alpha == -1: +# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) +# else: +# samples = ( +# self.xmin ** (1.0 + self.alpha) +# + q_samples +# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) +# return log_p + log_in_range @jaxtyped(typechecker=typechecker) From b9a725506012f80db3d297fe66f8e73fc632d37e Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 12:47:51 -0400 Subject: [PATCH 033/172] Update transforms.py --- src/jimgw/transforms.py | 55 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8ee9536..2feac5a2 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -175,3 +175,58 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arccos(x) + + +class PowerLawTransform(UnivariateTransform): + """ + PowerLaw transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + xmin: Float + xmax: Float + alpha: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + alpha: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.alpha = alpha + self.transform_func = lambda x: (self.xmin ** (1.0 + self.alpha)+ x* (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)))** (1.0 / (1.0 + self.alpha)), + + + +class ParetoTransform(UnivariateTransform): + """ + Pareto transformation: Power law when alpha = -1 + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.transform_func = lambda x: self.xmin * jnp.exp( + x * jnp.log(self.xmax / self.xmin) + ) From 2f6e12ac73915ccf0e2081ef241f0dc3cf9eb53e Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 16:15:41 -0400 Subject: [PATCH 034/172] format and updating typing hint --- src/jimgw/prior.py | 945 ++++++++++++++++++++-------------------- src/jimgw/transforms.py | 20 +- 2 files changed, 483 insertions(+), 482 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 1e7cb960..77d5713b 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -1,15 +1,11 @@ from dataclasses import field -from typing import Callable, Union import jax import jax.numpy as jnp -from astropy.time import Time from beartype import beartype as typechecker from flowMC.nfmodel.base import Distribution -from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped +from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped -from jimgw.single_event.detector import GroundBased2G, detector_preset -from jimgw.single_event.utils import zenith_azimuth_to_ra_dec from jimgw.transforms import Transform, Logit, Scale, Offset @@ -49,7 
+45,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ return dict(zip(self.parameter_names, x)) - + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -106,9 +102,7 @@ class SequentialTransform(Prior): transforms: list[Transform] def __repr__(self): - return ( - f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" - ) + return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" def __init__( self, @@ -127,7 +121,7 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - + def log_prob(self, x: dict[str, Float]) -> Float: """ log_prob has to be evaluated in the space of the base_prior. @@ -139,12 +133,13 @@ def log_prob(self, x: dict[str, Float]) -> Float: x, log_jacobian = transform.transform(x) output -= log_jacobian return output - + def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) return x + class Combine(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. 
@@ -184,7 +179,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output - @jaxtyped(typechecker=typechecker) class Uniform(Prior): _dist: SequentialTransform @@ -211,7 +205,8 @@ def __init__( Logit((parameter_names, parameter_names)), Scale((parameter_names, parameter_names), xmax - xmin), Offset((parameter_names, parameter_names), xmin), - ]) + ], + ) def sample( self, rng_key: PRNGKeyArray, n_samples: int @@ -220,474 +215,476 @@ def sample( def log_prob(self, x: dict[str, Array]) -> Float: return self._dist.log_prob(x) - + def sample_base( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: return self._dist.base_prior.sample(rng_key, n_samples) - + def transform(self, x: dict[str, Float]) -> dict[str, Float]: return self._dist.transform(x) -# ====================== Things below may need rework ====================== - -class Sphere(Prior): - """ - A prior on a sphere represented by Cartesian coordinates. - - Magnitude is sampled from a uniform distribution. 
- """ - - def __repr__(self): - return f"Sphere(naming={self.naming})" - - def __init__(self, naming: list[str], **kwargs): - name = naming[0] - self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] - self.transforms = { - self.naming[0]: ( - f"{naming}_x", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.cos(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[1]: ( - f"{naming}_y", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.sin(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[2]: ( - f"{naming}_z", - lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 3) - theta = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) - mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) - return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - theta = x[self.naming[0]] - phi = x[self.naming[1]] - mag = x[self.naming[2]] - output = jnp.where( - (mag > 1) - | (mag < 0) - | (phi > 2 * jnp.pi) - | (phi < 0) - | (theta > jnp.pi) - | (theta < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), - ) - return output - - -@jaxtyped(typechecker=typechecker) -class AlignedSpin(Prior): - """ - Prior distribution for the aligned (z) component of the spin. - - This assume the prior distribution on the spin magnitude to be uniform in [0, amax] - with its orientation uniform on a sphere - - p(chi) = -log(|chi| / amax) / 2 / amax - - This is useful when comparing results between an aligned-spin run and - a precessing spin run. - - See (A7) of https://arxiv.org/abs/1805.10457. 
- """ - - amax: Float = 0.99 - chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - - def __repr__(self): - return f"Alignedspin(amax={self.amax}, naming={self.naming})" - - def __init__( - self, - amax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" - self.amax = amax - - # build the interpolation table for the ppf of the one-sided distribution - chi_axis = jnp.linspace(1e-31, self.amax, num=1000) - cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax - self.chi_axis = chi_axis - self.cdf_vals = cdf_vals - - @property - def xmin(self): - return -self.amax - - @property - def xmax(self): - return self.amax - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from the Alignedspin distribution. - for chi > 0; - p(chi) = -log(chi / amax) / amax # halved normalization constant - cdf(chi) = -chi * (log(chi / amax) - 1) / amax - - Since there is a pole at chi=0, we will sample with the following steps - 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise - 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) - 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) - 3. Map the quantile to chi via the ppf by checking against the table - built during the initialization - 4. add back the sign - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - # 1. 
calculate the sign of chi from the q_samples - sign_samples = jnp.where( - q_samples >= 0.5, - jnp.zeros_like(q_samples) + 1.0, - jnp.zeros_like(q_samples) - 1.0, - ) - # 2. remap q_samples - q_samples = jnp.where( - q_samples >= 0.5, - 2 * (q_samples - 0.5), - 2 * (0.5 - q_samples), - ) - # 3. map the quantile to chi via interpolation - samples = jnp.interp( - q_samples, - self.cdf_vals, - self.chi_axis, - ) - # 4. add back the sign - samples *= sign_samples - - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_p = jnp.where( - (variable >= self.amax) | (variable <= -self.amax), - jnp.zeros_like(variable) - jnp.inf, - jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), - ) - return log_p - - -@jaxtyped(typechecker=typechecker) -class EarthFrame(Prior): - """ - Prior distribution for sky location in Earth frame. - """ - - ifos: list = field(default_factory=list) - gmst: float = 0.0 - delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) - - def __repr__(self): - return f"EarthFrame(naming={self.naming})" - - def __init__(self, gps: Float, ifos: list, **kwargs): - self.naming = ["zenith", "azimuth"] - if len(ifos) < 2: - return ValueError( - "At least two detectors are needed to define the Earth frame" - ) - elif isinstance(ifos[0], str): - self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] - elif isinstance(ifos[0], GroundBased2G): - self.ifos = ifos[:1] - else: - return ValueError( - "ifos should be a list of detector names or GroundBased2G objects" - ) - self.gmst = float( - Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad - ) - self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex - - self.transforms = { - "azimuth": ( - "ra", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[0], - ), - "zenith": ( - "dec", - lambda params: 
zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[1], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 2) - zenith = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - azimuth = jax.random.uniform( - rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi - ) - return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - zenith = x["zenith"] - azimuth = x["azimuth"] - output = jnp.where( - (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.zeros_like(0), - ) - return output + jnp.log(jnp.sin(zenith)) - - -@jaxtyped(typechecker=typechecker) -class PowerLaw(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ x^{\alpha} - """ - - xmin: float = 0.0 - xmax: float = 1.0 - alpha: float = 0.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: float, - xmax: float, - alpha: Union[Int, float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin > 0.0, "With negative alpha, xmin must > 0" - assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - if alpha == -1: - self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) - else: - self.normalization = (1 + self.alpha) / ( - self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a power-law distribution. 
- - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - if self.alpha == -1: - samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) - else: - samples = ( - self.xmin ** (1.0 + self.alpha) - + q_samples - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) - return log_p + log_in_range - - -@jaxtyped(typechecker=typechecker) -class Exponential(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). 
- p(x) ~ exp(\alpha x) - """ - - xmin: float = 0.0 - xmax: float = jnp.inf - alpha: float = -1.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: Float, - xmax: Float, - alpha: Union[Int, Float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin != -jnp.inf, "With negative alpha, xmin must finite" - if alpha > 0.0: - assert xmax != jnp.inf, "With positive alpha, xmax must finite" - assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" - assert self.n_dim == 1, "Exponential needs to be 1D distributions" - - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - - self.normalization = self.alpha / ( - jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a exponential distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. 
- - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - samples = ( - self.xmin - + jnp.log1p( - q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) - ) - / self.alpha - ) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * variable + jnp.log(self.normalization) - return log_p + log_in_range - - -@jaxtyped(typechecker=typechecker) -class Normal(Prior): - mean: Float = 0.0 - std: Float = 1.0 - - def __repr__(self): - return f"Normal(mean={self.mean}, std={self.std})" - - def __init__( - self, - mean: Float, - std: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Normal needs to be 1D distributions" - self.mean = mean - self.std = std - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a normal distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. +# ====================== Things below may need rework ====================== - """ - samples = jax.random.normal(rng_key, (n_samples,)) - samples = self.mean + samples * self.std - return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = ( - -0.5 * jnp.log(2 * jnp.pi) - - jnp.log(self.std) - - 0.5 * ((variable - self.mean) / self.std) ** 2 - ) - return output +# class Sphere(Prior): +# """ +# A prior on a sphere represented by Cartesian coordinates. 
+ +# Magnitude is sampled from a uniform distribution. +# """ + +# def __repr__(self): +# return f"Sphere(naming={self.naming})" + +# def __init__(self, naming: list[str], **kwargs): +# name = naming[0] +# self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] +# self.transforms = { +# self.naming[0]: ( +# f"{naming}_x", +# lambda params: jnp.sin(params[self.naming[0]]) +# * jnp.cos(params[self.naming[1]]) +# * params[self.naming[2]], +# ), +# self.naming[1]: ( +# f"{naming}_y", +# lambda params: jnp.sin(params[self.naming[0]]) +# * jnp.sin(params[self.naming[1]]) +# * params[self.naming[2]], +# ), +# self.naming[2]: ( +# f"{naming}_z", +# lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], +# ), +# } + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# rng_keys = jax.random.split(rng_key, 3) +# theta = jnp.arccos( +# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) +# ) +# phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) +# mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) +# return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# theta = x[self.naming[0]] +# phi = x[self.naming[1]] +# mag = x[self.naming[2]] +# output = jnp.where( +# (mag > 1) +# | (mag < 0) +# | (phi > 2 * jnp.pi) +# | (phi < 0) +# | (theta > jnp.pi) +# | (theta < 0), +# jnp.zeros_like(0) - jnp.inf, +# jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), +# ) +# return output + + +# @jaxtyped(typechecker=typechecker) +# class AlignedSpin(Prior): +# """ +# Prior distribution for the aligned (z) component of the spin. 
+ +# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] +# with its orientation uniform on a sphere + +# p(chi) = -log(|chi| / amax) / 2 / amax + +# This is useful when comparing results between an aligned-spin run and +# a precessing spin run. + +# See (A7) of https://arxiv.org/abs/1805.10457. +# """ + +# amax: Float = 0.99 +# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) +# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + +# def __repr__(self): +# return f"Alignedspin(amax={self.amax}, naming={self.naming})" + +# def __init__( +# self, +# amax: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" +# self.amax = amax + +# # build the interpolation table for the ppf of the one-sided distribution +# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) +# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax +# self.chi_axis = chi_axis +# self.cdf_vals = cdf_vals + +# @property +# def xmin(self): +# return -self.amax + +# @property +# def xmax(self): +# return self.amax + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from the Alignedspin distribution. + +# for chi > 0; +# p(chi) = -log(chi / amax) / amax # halved normalization constant +# cdf(chi) = -chi * (log(chi / amax) - 1) / amax + +# Since there is a pole at chi=0, we will sample with the following steps +# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise +# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) +# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) +# 3. Map the quantile to chi via the ppf by checking against the table +# built during the initialization +# 4. 
add back the sign + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# # 1. calculate the sign of chi from the q_samples +# sign_samples = jnp.where( +# q_samples >= 0.5, +# jnp.zeros_like(q_samples) + 1.0, +# jnp.zeros_like(q_samples) - 1.0, +# ) +# # 2. remap q_samples +# q_samples = jnp.where( +# q_samples >= 0.5, +# 2 * (q_samples - 0.5), +# 2 * (0.5 - q_samples), +# ) +# # 3. map the quantile to chi via interpolation +# samples = jnp.interp( +# q_samples, +# self.cdf_vals, +# self.chi_axis, +# ) +# # 4. add back the sign +# samples *= sign_samples + +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_p = jnp.where( +# (variable >= self.amax) | (variable <= -self.amax), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), +# ) +# return log_p + + +# @jaxtyped(typechecker=typechecker) +# class EarthFrame(Prior): +# """ +# Prior distribution for sky location in Earth frame. 
+# """ + +# ifos: list = field(default_factory=list) +# gmst: float = 0.0 +# delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) + +# def __repr__(self): +# return f"EarthFrame(naming={self.naming})" + +# def __init__(self, gps: Float, ifos: list, **kwargs): +# self.naming = ["zenith", "azimuth"] +# if len(ifos) < 2: +# return ValueError( +# "At least two detectors are needed to define the Earth frame" +# ) +# elif isinstance(ifos[0], str): +# self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] +# elif isinstance(ifos[0], GroundBased2G): +# self.ifos = ifos[:1] +# else: +# return ValueError( +# "ifos should be a list of detector names or GroundBased2G objects" +# ) +# self.gmst = float( +# Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad +# ) +# self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex + +# self.transforms = { +# "azimuth": ( +# "ra", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[0], +# ), +# "zenith": ( +# "dec", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[1], +# ), +# } + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# rng_keys = jax.random.split(rng_key, 2) +# zenith = jnp.arccos( +# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) +# ) +# azimuth = jax.random.uniform( +# rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi +# ) +# return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# zenith = x["zenith"] +# azimuth = x["azimuth"] +# output = jnp.where( +# (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), +# jnp.zeros_like(0) - jnp.inf, +# jnp.zeros_like(0), +# ) +# return output + jnp.log(jnp.sin(zenith)) + + +# 
@jaxtyped(typechecker=typechecker) +# class PowerLaw(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ x^{\alpha} +# """ + +# xmin: float = 0.0 +# xmax: float = 1.0 +# alpha: float = 0.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: float, +# xmax: float, +# alpha: Union[Int, float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin > 0.0, "With negative alpha, xmin must > 0" +# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha +# if alpha == -1: +# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) +# else: +# self.normalization = (1 + self.alpha) / ( +# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a power-law distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. 
+ +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# if self.alpha == -1: +# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) +# else: +# samples = ( +# self.xmin ** (1.0 + self.alpha) +# + q_samples +# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) +# return log_p + log_in_range + + +# @jaxtyped(typechecker=typechecker) +# class Exponential(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ exp(\alpha x) +# """ + +# xmin: float = 0.0 +# xmax: float = jnp.inf +# alpha: float = -1.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: Float, +# xmax: Float, +# alpha: Union[Int, Float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin != -jnp.inf, "With negative alpha, xmin must finite" +# if alpha > 0.0: +# assert xmax != jnp.inf, "With positive alpha, xmax must finite" +# assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" +# assert self.n_dim == 1, "Exponential needs to be 1D distributions" + +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha + +# self.normalization = self.alpha / ( +# jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " 
n_samples"]]: +# """ +# Sample from a exponential distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# samples = ( +# self.xmin +# + jnp.log1p( +# q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) +# ) +# / self.alpha +# ) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * variable + jnp.log(self.normalization) +# return log_p + log_in_range + + +# @jaxtyped(typechecker=typechecker) +# class Normal(Prior): +# mean: Float = 0.0 +# std: Float = 1.0 + +# def __repr__(self): +# return f"Normal(mean={self.mean}, std={self.std})" + +# def __init__( +# self, +# mean: Float, +# std: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Normal needs to be 1D distributions" +# self.mean = mean +# self.std = std + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a normal distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. 
+ +# """ +# samples = jax.random.normal(rng_key, (n_samples,)) +# samples = self.mean + samples * self.std +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Array]) -> Float: +# variable = x[self.naming[0]] +# output = ( +# -0.5 * jnp.log(2 * jnp.pi) +# - jnp.log(self.std) +# - 0.5 * ((variable - self.mean) / self.std) ** 2 +# ) +# return output diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 55a5d0ce..9e06d6ab 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -from dataclasses import field -from typing import Callable, Union +from typing import Callable import jax import jax.numpy as jnp -from beartype import beartype as typechecker from chex import assert_rank -from jaxtyping import Array, Float, jaxtyped +from jaxtyping import Float, Array + class Transform(ABC): """ @@ -16,7 +15,8 @@ class Transform(ABC): """ name_mapping: tuple[list[str], list[str]] - transform_func: Callable[[dict[str, Float]], dict[str, Float]] + transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -45,7 +45,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. 
""" raise NotImplementedError - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -92,7 +92,8 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x - + + class Scale(UnivariateTransform): scale: Float @@ -105,6 +106,7 @@ def __init__( self.scale = scale self.transform_func = lambda x: x * self.scale + class Offset(UnivariateTransform): offset: Float @@ -117,6 +119,7 @@ def __init__( self.offset = offset self.transform_func = lambda x: x + self.offset + class Logit(UnivariateTransform): """ Logit transform following @@ -135,10 +138,11 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + class ArcSine(UnivariateTransform): """ ArcSine transformation - + Parameters ---------- name_mapping : tuple[list[str], list[str]] From 194c565722beb10fd4a5c41934434583e7de9908 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 12:52:28 -0400 Subject: [PATCH 035/172] Revert "Update transforms.py" This reverts commit b9a725506012f80db3d297fe66f8e73fc632d37e. --- src/jimgw/transforms.py | 55 ----------------------------------------- 1 file changed, 55 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 2feac5a2..d8ee9536 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -175,58 +175,3 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arccos(x) - - -class PowerLawTransform(UnivariateTransform): - """ - PowerLaw transformation - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. 
- - """ - - xmin: Float - xmax: Float - alpha: Float - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - alpha: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.alpha = alpha - self.transform_func = lambda x: (self.xmin ** (1.0 + self.alpha)+ x* (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)))** (1.0 / (1.0 + self.alpha)), - - - -class ParetoTransform(UnivariateTransform): - """ - Pareto transformation: Power law when alpha = -1 - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.transform_func = lambda x: self.xmin * jnp.exp( - x * jnp.log(self.xmax / self.xmin) - ) From 5807f3c110564e5215344a5e00755027b8535b59 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 12:52:35 -0400 Subject: [PATCH 036/172] Revert "Update prior.py to include powerLaw" This reverts commit d761f6bfb83da27d83813f4cb1fce44a2882ebac. 
--- src/jimgw/prior.py | 191 ++++++++++++++++++--------------------------- 1 file changed, 78 insertions(+), 113 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 637bbc2b..de378973 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine, PowerLawTransform, ParetoTransform +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine class Prior(Distribution): @@ -311,41 +311,6 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) -@jaxtyped(typechecker=typechecker) -class PowerLaw(SequentialTransform): - xmin: float - xmax: float - alpha: float - - def __repr__(self): - return f"PowerLaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" - - def __init__( - self, - xmin: float, - xmax: float, - alpha: float, - parameter_names: list[str], - ): - self.parameter_names = parameter_names - assert self.n_dim == 1, "Power law needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - if self.alpha == -1.0: - transform = ParetoTransform((self.parameter_names, self.parameter_names), xmin, xmax) - else: - transform = PowerLawTransform( - (self.parameter_names, self.parameter_names), xmin, xmax, alpha - ) - super().__init__( - LogisticDistribution(self.parameter_names), - [ - Logit((self.parameter_names, self.parameter_names)), - transform, - ], - ) - # ====================== Things below may need rework ====================== @@ -539,83 +504,83 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output + jnp.log(jnp.sin(zenith)) -# @jaxtyped(typechecker=typechecker) -# class PowerLaw(Prior): -# """ -# A prior following the power-law with alpha in the range [xmin, xmax). 
-# p(x) ~ x^{\alpha} -# """ - -# xmin: float = 0.0 -# xmax: float = 1.0 -# alpha: float = 0.0 -# normalization: float = 1.0 - -# def __repr__(self): -# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - -# def __init__( -# self, -# xmin: float, -# xmax: float, -# alpha: Union[Int, float], -# naming: list[str], -# transforms: dict[str, tuple[str, Callable]] = {}, -# **kwargs, -# ): -# super().__init__(naming, transforms) -# if alpha < 0.0: -# assert xmin > 0.0, "With negative alpha, xmin must > 0" -# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" -# self.xmax = xmax -# self.xmin = xmin -# self.alpha = alpha -# if alpha == -1: -# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) -# else: -# self.normalization = (1 + self.alpha) / ( -# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) -# ) - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# """ -# Sample from a power-law distribution. - -# Parameters -# ---------- -# rng_key : PRNGKeyArray -# A random key to use for sampling. -# n_samples : int -# The number of samples to draw. - -# Returns -# ------- -# samples : dict -# Samples from the distribution. The keys are the names of the parameters. 
- -# """ -# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) -# if self.alpha == -1: -# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) -# else: -# samples = ( -# self.xmin ** (1.0 + self.alpha) -# + q_samples -# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) -# ) ** (1.0 / (1.0 + self.alpha)) -# return self.add_name(samples[None]) - -# def log_prob(self, x: dict[str, Float]) -> Float: -# variable = x[self.naming[0]] -# log_in_range = jnp.where( -# (variable >= self.xmax) | (variable <= self.xmin), -# jnp.zeros_like(variable) - jnp.inf, -# jnp.zeros_like(variable), -# ) -# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) -# return log_p + log_in_range +@jaxtyped(typechecker=typechecker) +class PowerLaw(Prior): + """ + A prior following the power-law with alpha in the range [xmin, xmax). + p(x) ~ x^{\alpha} + """ + + xmin: float = 0.0 + xmax: float = 1.0 + alpha: float = 0.0 + normalization: float = 1.0 + + def __repr__(self): + return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + + def __init__( + self, + xmin: float, + xmax: float, + alpha: Union[Int, float], + naming: list[str], + transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, + ): + super().__init__(naming, transforms) + if alpha < 0.0: + assert xmin > 0.0, "With negative alpha, xmin must > 0" + assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + if alpha == -1: + self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) + else: + self.normalization = (1 + self.alpha) / ( + self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) + ) + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a power-law distribution. + + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. 
+ n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) + if self.alpha == -1: + samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) + else: + samples = ( + self.xmin ** (1.0 + self.alpha) + + q_samples + * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) ** (1.0 / (1.0 + self.alpha)) + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Float]) -> Float: + variable = x[self.naming[0]] + log_in_range = jnp.where( + (variable >= self.xmax) | (variable <= self.xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), + ) + log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) + return log_p + log_in_range @jaxtyped(typechecker=typechecker) From 8256d0ee94c4d787d0257280d01461ce5127a8cd Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 13:00:09 -0400 Subject: [PATCH 037/172] Comment out old prior --- src/jimgw/prior.py | 692 ++++++++++++++++++++++----------------------- 1 file changed, 346 insertions(+), 346 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index dc983a50..499bd1c2 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -313,349 +313,349 @@ def __init__(self, parameter_names: list[str], **kwargs): # ====================== Things below may need rework ====================== -@jaxtyped(typechecker=typechecker) -class AlignedSpin(Prior): - """ - Prior distribution for the aligned (z) component of the spin. - - This assume the prior distribution on the spin magnitude to be uniform in [0, amax] - with its orientation uniform on a sphere - - p(chi) = -log(|chi| / amax) / 2 / amax - - This is useful when comparing results between an aligned-spin run and - a precessing spin run. - - See (A7) of https://arxiv.org/abs/1805.10457. 
- """ - - amax: Float = 0.99 - chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - - def __repr__(self): - return f"Alignedspin(amax={self.amax}, naming={self.naming})" - - def __init__( - self, - amax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" - self.amax = amax - - # build the interpolation table for the ppf of the one-sided distribution - chi_axis = jnp.linspace(1e-31, self.amax, num=1000) - cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax - self.chi_axis = chi_axis - self.cdf_vals = cdf_vals - - @property - def xmin(self): - return -self.amax - - @property - def xmax(self): - return self.amax - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from the Alignedspin distribution. - - for chi > 0; - p(chi) = -log(chi / amax) / amax # halved normalization constant - cdf(chi) = -chi * (log(chi / amax) - 1) / amax - - Since there is a pole at chi=0, we will sample with the following steps - 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise - 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) - 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) - 3. Map the quantile to chi via the ppf by checking against the table - built during the initialization - 4. add back the sign - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - # 1. 
calculate the sign of chi from the q_samples - sign_samples = jnp.where( - q_samples >= 0.5, - jnp.zeros_like(q_samples) + 1.0, - jnp.zeros_like(q_samples) - 1.0, - ) - # 2. remap q_samples - q_samples = jnp.where( - q_samples >= 0.5, - 2 * (q_samples - 0.5), - 2 * (0.5 - q_samples), - ) - # 3. map the quantile to chi via interpolation - samples = jnp.interp( - q_samples, - self.cdf_vals, - self.chi_axis, - ) - # 4. add back the sign - samples *= sign_samples - - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_p = jnp.where( - (variable >= self.amax) | (variable <= -self.amax), - jnp.zeros_like(variable) - jnp.inf, - jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), - ) - return log_p - - -@jaxtyped(typechecker=typechecker) -class EarthFrame(Prior): - """ - Prior distribution for sky location in Earth frame. - """ - - ifos: list = field(default_factory=list) - gmst: float = 0.0 - delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) - - def __repr__(self): - return f"EarthFrame(naming={self.naming})" - - def __init__(self, gps: Float, ifos: list, **kwargs): - self.naming = ["zenith", "azimuth"] - if len(ifos) < 2: - return ValueError( - "At least two detectors are needed to define the Earth frame" - ) - elif isinstance(ifos[0], str): - self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] - elif isinstance(ifos[0], GroundBased2G): - self.ifos = ifos[:1] - else: - return ValueError( - "ifos should be a list of detector names or GroundBased2G objects" - ) - self.gmst = float( - Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad - ) - self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex - - self.transforms = { - "azimuth": ( - "ra", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[0], - ), - "zenith": ( - "dec", - lambda params: 
zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[1], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 2) - zenith = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - azimuth = jax.random.uniform( - rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi - ) - return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - zenith = x["zenith"] - azimuth = x["azimuth"] - output = jnp.where( - (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.zeros_like(0), - ) - return output + jnp.log(jnp.sin(zenith)) - - -@jaxtyped(typechecker=typechecker) -class PowerLaw(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ x^{\alpha} - """ - - xmin: float = 0.0 - xmax: float = 1.0 - alpha: float = 0.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: float, - xmax: float, - alpha: Union[Int, float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin > 0.0, "With negative alpha, xmin must > 0" - assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - if alpha == -1: - self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) - else: - self.normalization = (1 + self.alpha) / ( - self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a power-law distribution. 
- - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - if self.alpha == -1: - samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) - else: - samples = ( - self.xmin ** (1.0 + self.alpha) - + q_samples - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) - return log_p + log_in_range - - -@jaxtyped(typechecker=typechecker) -class Exponential(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). 
- p(x) ~ exp(\alpha x) - """ - - xmin: float = 0.0 - xmax: float = jnp.inf - alpha: float = -1.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: Float, - xmax: Float, - alpha: Union[Int, Float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin != -jnp.inf, "With negative alpha, xmin must finite" - if alpha > 0.0: - assert xmax != jnp.inf, "With positive alpha, xmax must finite" - assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" - assert self.n_dim == 1, "Exponential needs to be 1D distributions" - - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - - self.normalization = self.alpha / ( - jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a exponential distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. 
- - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - samples = ( - self.xmin - + jnp.log1p( - q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) - ) - / self.alpha - ) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * variable + jnp.log(self.normalization) - return log_p + log_in_range +# @jaxtyped(typechecker=typechecker) +# class AlignedSpin(Prior): +# """ +# Prior distribution for the aligned (z) component of the spin. + +# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] +# with its orientation uniform on a sphere + +# p(chi) = -log(|chi| / amax) / 2 / amax + +# This is useful when comparing results between an aligned-spin run and +# a precessing spin run. + +# See (A7) of https://arxiv.org/abs/1805.10457. 
+# """ + +# amax: Float = 0.99 +# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) +# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + +# def __repr__(self): +# return f"Alignedspin(amax={self.amax}, naming={self.naming})" + +# def __init__( +# self, +# amax: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" +# self.amax = amax + +# # build the interpolation table for the ppf of the one-sided distribution +# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) +# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax +# self.chi_axis = chi_axis +# self.cdf_vals = cdf_vals + +# @property +# def xmin(self): +# return -self.amax + +# @property +# def xmax(self): +# return self.amax + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from the Alignedspin distribution. + +# for chi > 0; +# p(chi) = -log(chi / amax) / amax # halved normalization constant +# cdf(chi) = -chi * (log(chi / amax) - 1) / amax + +# Since there is a pole at chi=0, we will sample with the following steps +# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise +# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) +# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) +# 3. Map the quantile to chi via the ppf by checking against the table +# built during the initialization +# 4. add back the sign + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. 
+ +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# # 1. calculate the sign of chi from the q_samples +# sign_samples = jnp.where( +# q_samples >= 0.5, +# jnp.zeros_like(q_samples) + 1.0, +# jnp.zeros_like(q_samples) - 1.0, +# ) +# # 2. remap q_samples +# q_samples = jnp.where( +# q_samples >= 0.5, +# 2 * (q_samples - 0.5), +# 2 * (0.5 - q_samples), +# ) +# # 3. map the quantile to chi via interpolation +# samples = jnp.interp( +# q_samples, +# self.cdf_vals, +# self.chi_axis, +# ) +# # 4. add back the sign +# samples *= sign_samples + +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_p = jnp.where( +# (variable >= self.amax) | (variable <= -self.amax), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), +# ) +# return log_p + + +# @jaxtyped(typechecker=typechecker) +# class EarthFrame(Prior): +# """ +# Prior distribution for sky location in Earth frame. 
+# """ + +# ifos: list = field(default_factory=list) +# gmst: float = 0.0 +# delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) + +# def __repr__(self): +# return f"EarthFrame(naming={self.naming})" + +# def __init__(self, gps: Float, ifos: list, **kwargs): +# self.naming = ["zenith", "azimuth"] +# if len(ifos) < 2: +# return ValueError( +# "At least two detectors are needed to define the Earth frame" +# ) +# elif isinstance(ifos[0], str): +# self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] +# elif isinstance(ifos[0], GroundBased2G): +# self.ifos = ifos[:1] +# else: +# return ValueError( +# "ifos should be a list of detector names or GroundBased2G objects" +# ) +# self.gmst = float( +# Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad +# ) +# self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex + +# self.transforms = { +# "azimuth": ( +# "ra", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[0], +# ), +# "zenith": ( +# "dec", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[1], +# ), +# } + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# rng_keys = jax.random.split(rng_key, 2) +# zenith = jnp.arccos( +# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) +# ) +# azimuth = jax.random.uniform( +# rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi +# ) +# return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# zenith = x["zenith"] +# azimuth = x["azimuth"] +# output = jnp.where( +# (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), +# jnp.zeros_like(0) - jnp.inf, +# jnp.zeros_like(0), +# ) +# return output + jnp.log(jnp.sin(zenith)) + + +# 
@jaxtyped(typechecker=typechecker) +# class PowerLaw(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ x^{\alpha} +# """ + +# xmin: float = 0.0 +# xmax: float = 1.0 +# alpha: float = 0.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: float, +# xmax: float, +# alpha: Union[Int, float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin > 0.0, "With negative alpha, xmin must > 0" +# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha +# if alpha == -1: +# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) +# else: +# self.normalization = (1 + self.alpha) / ( +# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a power-law distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. 
+ +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# if self.alpha == -1: +# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) +# else: +# samples = ( +# self.xmin ** (1.0 + self.alpha) +# + q_samples +# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) +# return log_p + log_in_range + + +# @jaxtyped(typechecker=typechecker) +# class Exponential(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ exp(\alpha x) +# """ + +# xmin: float = 0.0 +# xmax: float = jnp.inf +# alpha: float = -1.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: Float, +# xmax: Float, +# alpha: Union[Int, Float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin != -jnp.inf, "With negative alpha, xmin must finite" +# if alpha > 0.0: +# assert xmax != jnp.inf, "With positive alpha, xmax must finite" +# assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" +# assert self.n_dim == 1, "Exponential needs to be 1D distributions" + +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha + +# self.normalization = self.alpha / ( +# jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " 
n_samples"]]: +# """ +# Sample from a exponential distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# samples = ( +# self.xmin +# + jnp.log1p( +# q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) +# ) +# / self.alpha +# ) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * variable + jnp.log(self.normalization) +# return log_p + log_in_range From b42b0f4622b2a457ae759eabbbbc6c6641af8da4 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 13:01:07 -0400 Subject: [PATCH 038/172] Reformat --- src/jimgw/prior.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 499bd1c2..5e48cb8a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -6,8 +6,6 @@ from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped -from jimgw.single_event.detector import GroundBased2G, detector_preset -from jimgw.single_event.utils import zenith_azimuth_to_ra_dec from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine From a9629b25e984c104375a2710fd2c8e496faaa5b8 Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 26 Jul 2024 13:09:46 -0400 Subject: [PATCH 039/172] Update Prior naming --- src/jimgw/prior.py | 34 +++++++++++++++++----------------- test/test_prior.py | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 
5e48cb8a..0f8b70b1 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -132,7 +132,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) -class SequentialTransform(Prior): +class SequentialTransformPrior(Prior): """ Transform a prior distribution by applying a sequence of transforms. """ @@ -182,7 +182,7 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]: return x -class Combine(Prior): +class CombinePrior(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. This assumes the priors composing the Combine class are independent. @@ -220,12 +220,12 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped(typechecker=typechecker) -class Uniform(SequentialTransform): +class UniformPrior(SequentialTransformPrior): xmin: float xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" + return f"UniformPrior(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" def __init__( self, @@ -234,7 +234,7 @@ def __init__( parameter_names: list[str], ): self.parameter_names = parameter_names - assert self.n_dim == 1, "Uniform needs to be 1D distributions" + assert self.n_dim == 1, "UniformPrior needs to be 1D distributions" self.xmax = xmax self.xmin = xmin super().__init__( @@ -248,43 +248,43 @@ def __init__( @jaxtyped(typechecker=typechecker) -class Sine(SequentialTransform): +class SinePrior(SequentialTransformPrior): """ A prior distribution where the pdf is proportional to sin(x) in the range [0, pi]. 
""" def __repr__(self): - return f"Sine(parameter_names={self.parameter_names})" + return f"SinePrior(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names - assert self.n_dim == 1, "Sine needs to be 1D distributions" + assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( - Uniform(-1.0, 1.0, f"cos_{self.parameter_names}"), + UniformPrior(-1.0, 1.0, f"cos_{self.parameter_names}"), [ArcCosine(([f"cos_{self.parameter_names}"], [self.parameter_names]))], ) @jaxtyped(typechecker=typechecker) -class Cosine(SequentialTransform): +class CosinePrior(SequentialTransformPrior): """ A prior distribution where the pdf is proportional to cos(x) in the range [-pi/2, pi/2]. """ def __repr__(self): - return f"Cosine(parameter_names={self.parameter_names})" + return f"CosinePrior(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names - assert self.n_dim == 1, "Cosine needs to be 1D distributions" + assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( - Uniform(-1.0, 1.0, f"sin_{self.parameter_names}"), + UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names}"), [ArcSine(([f"sin_{self.parameter_names}"], [self.parameter_names]))], ) @jaxtyped(typechecker=typechecker) -class UniformSphere(Combine): +class UniformSpherePrior(CombinePrior): def __repr__(self): return f"UniformSphere(parameter_names={self.parameter_names})" @@ -301,9 +301,9 @@ def __init__(self, parameter_names: list[str], **kwargs): ] super().__init__( [ - Uniform(0.0, 1.0, [self.parameter_names[0]]), - Sine([self.parameter_names[1]]), - Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + UniformPrior(0.0, 1.0, [self.parameter_names[0]]), + SinePrior([self.parameter_names[1]]), + UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]), ] ) diff --git a/test/test_prior.py b/test/test_prior.py index 
b6eb7c87..98803075 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -6,7 +6,7 @@ def test_logistic(self): p = Logit() def test_uniform(self): - p = Uniform(0.0, 10.0, ['x']) + p = UniformPrior(0.0, 10.0, ['x']) samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, -jnp.log(10.0)) From fbba5af599f87fae064e12d6d593398d48b3f8ce Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 26 Jul 2024 13:12:13 -0400 Subject: [PATCH 040/172] Standize Transform naming --- src/jimgw/prior.py | 29 +++++++++++++++++++++++------ src/jimgw/transforms.py | 10 +++++----- test/test_prior.py | 11 ++++++----- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 0f8b70b1..3d26eb04 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -6,7 +6,14 @@ from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine +from jimgw.transforms import ( + Transform, + LogitTransform, + ScaleTransform, + OffsetTransform, + ArcSineTransform, + ArcCosineTransform, +) class Prior(Distribution): @@ -240,9 +247,11 @@ def __init__( super().__init__( LogisticDistribution(self.parameter_names), [ - Logit((self.parameter_names, self.parameter_names)), - Scale((self.parameter_names, self.parameter_names), xmax - xmin), - Offset((self.parameter_names, self.parameter_names), xmin), + LogitTransform((self.parameter_names, self.parameter_names)), + ScaleTransform( + (self.parameter_names, self.parameter_names), xmax - xmin + ), + OffsetTransform((self.parameter_names, self.parameter_names), xmin), ], ) @@ -261,7 +270,11 @@ def __init__(self, parameter_names: list[str]): assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( UniformPrior(-1.0, 1.0, f"cos_{self.parameter_names}"), - 
[ArcCosine(([f"cos_{self.parameter_names}"], [self.parameter_names]))], + [ + ArcCosineTransform( + ([f"cos_{self.parameter_names}"], [self.parameter_names]) + ) + ], ) @@ -279,7 +292,11 @@ def __init__(self, parameter_names: list[str]): assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names}"), - [ArcSine(([f"sin_{self.parameter_names}"], [self.parameter_names]))], + [ + ArcSineTransform( + ([f"sin_{self.parameter_names}"], [self.parameter_names]) + ) + ], ) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8ee9536..5022fc88 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -94,7 +94,7 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: return x -class Scale(UnivariateTransform): +class ScaleTransform(UnivariateTransform): scale: Float def __init__( @@ -107,7 +107,7 @@ def __init__( self.transform_func = lambda x: x * self.scale -class Offset(UnivariateTransform): +class OffsetTransform(UnivariateTransform): offset: Float def __init__( @@ -120,7 +120,7 @@ def __init__( self.transform_func = lambda x: x + self.offset -class Logit(UnivariateTransform): +class LogitTransform(UnivariateTransform): """ Logit transform following @@ -139,7 +139,7 @@ def __init__( self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) -class ArcSine(UnivariateTransform): +class ArcSineTransform(UnivariateTransform): """ ArcSine transformation @@ -158,7 +158,7 @@ def __init__( self.transform_func = lambda x: jnp.arcsin(x) -class ArcCosine(UnivariateTransform): +class ArcCosineTransform(UnivariateTransform): """ ArcCosine transformation diff --git a/test/test_prior.py b/test/test_prior.py index 98803075..6431d7d0 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -3,20 +3,21 @@ class TestUnivariatePrior: def test_logistic(self): - p = Logit() + p = LogitTransform() def test_uniform(self): - p = UniformPrior(0.0, 10.0, ['x']) + p = UniformPrior(0.0, 
10.0, ["x"]) samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, -jnp.log(10.0)) + class TestPriorOperations: def test_combine(self): raise NotImplementedError - + def test_sequence(self): raise NotImplementedError - + def test_factor(self): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError From 4a5d932aea617a36d2e7e26a7bf61ff0007953d6 Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 26 Jul 2024 15:13:55 -0400 Subject: [PATCH 041/172] Fixing naming problem and add base_distribution tracer --- src/jimgw/prior.py | 44 ++++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3d26eb04..dc6a492b 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -27,6 +27,7 @@ class Prior(Distribution): """ parameter_names: list[str] + composite: bool = False @property def n_dim(self): @@ -61,7 +62,6 @@ def sample( def log_prob(self, x: dict[str, Array]) -> Float: raise NotImplementedError - @jaxtyped(typechecker=typechecker) class LogisticDistribution(Prior): @@ -70,6 +70,7 @@ def __repr__(self): def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) + self.composite = False assert self.n_dim == 1, "LogisticDistribution needs to be 1D distributions" def sample( @@ -108,6 +109,7 @@ def __repr__(self): def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) + self.composite = False assert ( self.n_dim == 1 ), "StandardNormalDistribution needs to be 1D distributions" @@ -161,11 +163,12 @@ def __init__( self.parameter_names = base_prior.parameter_names for transform in transforms: self.parameter_names = transform.propagate_name(self.parameter_names) + self.composite = True def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - output = 
self.sample_base(rng_key, n_samples) + output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) def log_prob(self, x: dict[str, Float]) -> Float: @@ -178,11 +181,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: output -= log_jacobian return output - def sample_base( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self.base_prior.sample(rng_key, n_samples) - def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) @@ -195,10 +193,12 @@ class CombinePrior(Prior): This assumes the priors composing the Combine class are independent. """ - priors: list[Prior] = field(default_factory=list) + base_prior: list[Prior] = field(default_factory=list) def __repr__(self): - return f"Combine(priors={self.priors}, parameter_names={self.parameter_names})" + return ( + f"Combine(priors={self.base_prior}, parameter_names={self.parameter_names})" + ) def __init__( self, @@ -207,21 +207,22 @@ def __init__( parameter_names = [] for prior in priors: parameter_names += prior.parameter_names - self.priors = priors + self.base_prior = priors self.parameter_names = parameter_names + self.composite = True def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: output = {} - for prior in self.priors: + for prior in self.base_prior: rng_key, subkey = jax.random.split(rng_key) output.update(prior.sample(subkey, n_samples)) return output def log_prob(self, x: dict[str, Float]) -> Float: output = 0.0 - for prior in self.priors: + for prior in self.base_prior: output += prior.log_prob(x) return output @@ -269,10 +270,10 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, f"cos_{self.parameter_names}"), + UniformPrior(-1.0, 1.0, 
[f"cos_{self.parameter_names[0]}"]), [ ArcCosineTransform( - ([f"cos_{self.parameter_names}"], [self.parameter_names]) + ([f"cos_{self.parameter_names[0]}"], [self.parameter_names[0]]) ) ], ) @@ -291,10 +292,10 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names}"), + UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names[0]}"), [ ArcSineTransform( - ([f"sin_{self.parameter_names}"], [self.parameter_names]) + ([f"sin_{self.parameter_names[0]}"], [self.parameter_names[0]]) ) ], ) @@ -324,6 +325,17 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) +def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output # ====================== Things below may need rework ====================== From 27dc87090ff3ed29cb52792f957d52de3595a6c8 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 15:37:19 -0400 Subject: [PATCH 042/172] Updated powerlaw --- src/jimgw/prior.py | 43 +++++++++++++++++++++++++++++++ src/jimgw/transforms.py | 57 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index dc6a492b..16c28b8e 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -13,6 +13,8 @@ OffsetTransform, ArcSineTransform, ArcCosineTransform, + PowerLawTransform, + ParetoTransform, ) @@ -62,6 +64,7 @@ def sample( def log_prob(self, x: dict[str, Array]) -> Float: raise NotImplementedError + @jaxtyped(typechecker=typechecker) class LogisticDistribution(Prior): @@ 
-325,6 +328,45 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) + +@jaxtyped(typechecker=typechecker) +class PowerLawPrior(SequentialTransformPrior): + xmin: float + xmax: float + alpha: float + + def __repr__(self): + return f"PowerLaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" + + def __init__( + self, + xmin: float, + xmax: float, + alpha: float, + parameter_names: list[str], + ): + self.parameter_names = parameter_names + assert self.n_dim == 1, "Power law needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + if self.alpha == -1.0: + transform = ParetoTransform( + (self.parameter_names, self.parameter_names), xmin, xmax + ) + else: + transform = PowerLawTransform( + (self.parameter_names, self.parameter_names), xmin, xmax, alpha + ) + super().__init__( + LogisticDistribution(self.parameter_names), + [ + Logit((self.parameter_names, self.parameter_names)), + transform, + ], + ) + + def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: if prior.composite: if isinstance(prior.base_prior, list): @@ -337,6 +379,7 @@ def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: return output + # ====================== Things below may need rework ====================== diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 5022fc88..527ea7b8 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -175,3 +175,60 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arccos(x) + + +class PowerLawTransform(UnivariateTransform): + """ + PowerLaw transformation + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ """ + + xmin: Float + xmax: Float + alpha: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + alpha: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.alpha = alpha + self.transform_func = ( + lambda x: ( + self.xmin ** (1.0 + self.alpha) + + x + * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + ** (1.0 / (1.0 + self.alpha)), + ) + + +class ParetoTransform(UnivariateTransform): + """ + Pareto transformation: Power law when alpha = -1 + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.transform_func = lambda x: self.xmin * jnp.exp( + x * jnp.log(self.xmax / self.xmin) + ) From f7b883d6d0bff8f5be1eaddfbd042e601b1b620a Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 15:38:28 -0400 Subject: [PATCH 043/172] Updated powerlaw --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 16c28b8e..9f52fdbb 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -361,7 +361,7 @@ def __init__( super().__init__( LogisticDistribution(self.parameter_names), [ - Logit((self.parameter_names, self.parameter_names)), + LogitTransform((self.parameter_names, self.parameter_names)), transform, ], ) From 1665b1cfbc0400647f1ccabb2ed79479c4f32c43 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 16:26:20 -0400 Subject: [PATCH 044/172] Updated powerlaw --- src/jimgw/test.py | 70 +++++++++++++++++++++++++++++++++++++++++ src/jimgw/transforms.py | 12 +++---- 2 files changed, 74 insertions(+), 8 deletions(-) create 
mode 100644 src/jimgw/test.py diff --git a/src/jimgw/test.py b/src/jimgw/test.py new file mode 100644 index 00000000..c114b190 --- /dev/null +++ b/src/jimgw/test.py @@ -0,0 +1,70 @@ +import jax +import jax.numpy as jnp +from jimgw.prior import PowerLawPrior + +alpha = -1.0 +xmin = 1.0 +xmax = 10.0 + + +q_samples = jax.random.uniform(jax.random.PRNGKey(42), (100,), minval=0.0, maxval=1.0) +if alpha == -1: + samples = xmin * jnp.exp(q_samples * jnp.log(xmax / xmin)) +else: + samples = ( + xmin ** (1.0 + alpha) + + q_samples * (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)) + ) ** (1.0 / (1.0 + alpha)) +samples[None] + +PowerLawPrior(xmin, xmax, alpha, ["x"]).sample(jax.random.PRNGKey(42), 100) + + +############## +if alpha == -1: + normalization = float(1.0 / jnp.log(xmax / xmin)) +else: + normalization = (1 + alpha) / (xmax ** (1 + alpha) - xmin ** (1 + alpha)) +variable = 1.5 +log_in_range = jnp.where( + (variable >= xmax) | (variable <= xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), +) +log_p = alpha * jnp.log(variable) + jnp.log(normalization) +log_p + log_in_range + +PowerLawPrior(xmin, xmax, alpha, ["x"]).log_prob({"x": variable}) + + +# transform_func = lambda x: ( +# xmin ** (1.0 + alpha) + x * (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)) +# ) ** (1.0 / (1.0 + alpha)) +# input_params = variable +# assert_rank(input_params, 0) +# output_params = transform_func(input_params) +# jacobian = jax.jacfwd(transform_func)(input_params) +# x[variable] = output_params +# return x, jnp.log(jacobian) + + +import numpy as np + +alpha = -2.0 +xmin = 1.0 +xmax = 20.0 +p = PowerLawPrior(xmin, xmax, alpha, ["x"]) +grid = np.linspace(xmin, xmax, 20) +transform = [] +log_prob = [] +for y in grid: + transform.append(p.transform({"x": y})["x"].item()) + log_prob.append(np.exp(p.log_prob({"x": y}).item())) +import matplotlib.pyplot as plt + +plt.plot(grid, transform) +plt.savefig("transform.png") +plt.close() +plt.plot(transform, log_prob) 
+plt.savefig("log_prob.png") +plt.close() diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 527ea7b8..da84a059 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -201,14 +201,10 @@ def __init__( self.xmin = xmin self.xmax = xmax self.alpha = alpha - self.transform_func = ( - lambda x: ( - self.xmin ** (1.0 + self.alpha) - + x - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) - ** (1.0 / (1.0 + self.alpha)), - ) + self.transform_func = lambda x: ( + self.xmin ** (1.0 + self.alpha) + + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) ** (1.0 / (1.0 + self.alpha)) class ParetoTransform(UnivariateTransform): From f51847d614d14de03569317e28cc738a9d89d915 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 16:49:15 -0400 Subject: [PATCH 045/172] Updated powerlaw --- src/jimgw/test.py | 70 ----------------------------------------------- 1 file changed, 70 deletions(-) delete mode 100644 src/jimgw/test.py diff --git a/src/jimgw/test.py b/src/jimgw/test.py deleted file mode 100644 index c114b190..00000000 --- a/src/jimgw/test.py +++ /dev/null @@ -1,70 +0,0 @@ -import jax -import jax.numpy as jnp -from jimgw.prior import PowerLawPrior - -alpha = -1.0 -xmin = 1.0 -xmax = 10.0 - - -q_samples = jax.random.uniform(jax.random.PRNGKey(42), (100,), minval=0.0, maxval=1.0) -if alpha == -1: - samples = xmin * jnp.exp(q_samples * jnp.log(xmax / xmin)) -else: - samples = ( - xmin ** (1.0 + alpha) - + q_samples * (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)) - ) ** (1.0 / (1.0 + alpha)) -samples[None] - -PowerLawPrior(xmin, xmax, alpha, ["x"]).sample(jax.random.PRNGKey(42), 100) - - -############## -if alpha == -1: - normalization = float(1.0 / jnp.log(xmax / xmin)) -else: - normalization = (1 + alpha) / (xmax ** (1 + alpha) - xmin ** (1 + alpha)) -variable = 1.5 -log_in_range = jnp.where( - (variable >= xmax) | (variable <= xmin), - 
jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), -) -log_p = alpha * jnp.log(variable) + jnp.log(normalization) -log_p + log_in_range - -PowerLawPrior(xmin, xmax, alpha, ["x"]).log_prob({"x": variable}) - - -# transform_func = lambda x: ( -# xmin ** (1.0 + alpha) + x * (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)) -# ) ** (1.0 / (1.0 + alpha)) -# input_params = variable -# assert_rank(input_params, 0) -# output_params = transform_func(input_params) -# jacobian = jax.jacfwd(transform_func)(input_params) -# x[variable] = output_params -# return x, jnp.log(jacobian) - - -import numpy as np - -alpha = -2.0 -xmin = 1.0 -xmax = 20.0 -p = PowerLawPrior(xmin, xmax, alpha, ["x"]) -grid = np.linspace(xmin, xmax, 20) -transform = [] -log_prob = [] -for y in grid: - transform.append(p.transform({"x": y})["x"].item()) - log_prob.append(np.exp(p.log_prob({"x": y}).item())) -import matplotlib.pyplot as plt - -plt.plot(grid, transform) -plt.savefig("transform.png") -plt.close() -plt.plot(transform, log_prob) -plt.savefig("log_prob.png") -plt.close() From f7876e6403a2ce90db95bcbd11a3daa335730b0f Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 17:09:42 -0400 Subject: [PATCH 046/172] Updated prior.py to include UniformComponenChirpMassPrior --- src/jimgw/prior.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 9f52fdbb..c182ee16 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -367,6 +367,22 @@ def __init__( ) +@jaxtyped(typechecker=typechecker) +class UniformComponenChirpMassPrior(PowerLawPrior): + """ + A prior in the range [xmin, xmax) for chirp mass which assumes the + component mass to be uniform. 
+ + p(\cal M) ~ \cal M + """ + + def __repr__(self): + return f"UniformComponentChirpMass(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + + def __init__(self, xmin: float, xmax: float, parameter_names: list[str]): + super().__init__(xmin, xmax, 1.0, parameter_names) + + def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: if prior.composite: if isinstance(prior.base_prior, list): From 0b2c7cf296b152753b0f757dec7333f97dd668c9 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 21:16:29 -0400 Subject: [PATCH 047/172] Fix priors --- src/jimgw/prior.py | 26 ++++++++++++-------------- src/jimgw/transforms.py | 18 ------------------ 2 files changed, 12 insertions(+), 32 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index dc6a492b..b09e8cc7 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -12,7 +12,6 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - ArcCosineTransform, ) @@ -30,7 +29,7 @@ class Prior(Distribution): composite: bool = False @property - def n_dim(self): + def n_dim(self) -> int: return len(self.parameter_names) def __init__(self, parameter_names: list[str]): @@ -270,12 +269,13 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, [f"cos_{self.parameter_names[0]}"]), + CosinePrior([f"{self.parameter_names[0]}"]), [ - ArcCosineTransform( - ([f"cos_{self.parameter_names[0]}"], [self.parameter_names[0]]) + OffsetTransform( + (([f"{self.parameter_names[0]}"], [f"{self.parameter_names[0]}"])), + jnp.pi / 2, ) - ], + ] ) @@ -292,7 +292,7 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names[0]}"), + UniformPrior(-1.0, 1.0, 
[f"sin_{self.parameter_names[0]}"]), [ ArcSineTransform( ([f"sin_{self.parameter_names[0]}"], [self.parameter_names[0]]) @@ -308,14 +308,12 @@ def __repr__(self): return f"UniformSphere(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): - assert ( - len(parameter_names) == 1 - ), "UniformSphere only takes the name of the vector" - parameter_names = parameter_names[0] + self.parameter_names = parameter_names + assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector" self.parameter_names = [ - f"{parameter_names}_mag", - f"{parameter_names}_theta", - f"{parameter_names}_phi", + f"{self.parameter_names[0]}_mag", + f"{self.parameter_names[0]}_theta", + f"{self.parameter_names[0]}_phi", ] super().__init__( [ diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 5022fc88..3ea71e9c 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -157,21 +157,3 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) - -class ArcCosineTransform(UnivariateTransform): - """ - ArcCosine transformation - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. 
- - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - self.transform_func = lambda x: jnp.arccos(x) From d917f96da3b764980885fc811b7120b4a9183342 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 21:17:57 -0400 Subject: [PATCH 048/172] Add prior tests --- test/test_prior.py | 69 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 13 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index 6431d7d0..48902b3b 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -1,23 +1,66 @@ from jimgw.prior import * +import scipy.stats as stats class TestUnivariatePrior: def test_logistic(self): - p = LogitTransform() - - def test_uniform(self): - p = UniformPrior(0.0, 10.0, ["x"]) - samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000) + p = LogisticDistribution(["x"]) + # Check that the log_prob is finite + samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) - assert jnp.allclose(log_prob, -jnp.log(10.0)) + assert jnp.all(jnp.isfinite(log_prob)) + # Cross-check log_prob with scipy.stats.logistic + x = jnp.linspace(-10.0, 10.0, 1000) + assert jnp.allclose(jax.vmap(p.log_prob)(p.add_name(x[None])), stats.logistic.logpdf(x)) + def test_standard_normal(self): + p = StandardNormalDistribution(["x"]) + # Check that the log_prob is finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Cross-check log_prob with scipy.stats.norm + x = jnp.linspace(-10.0, 10.0, 1000) + assert jnp.allclose(jax.vmap(p.log_prob)(p.add_name(x[None])), stats.norm.logpdf(x)) -class TestPriorOperations: - def test_combine(self): - raise NotImplementedError + def test_uniform(self): + xmin, xmax = -10.0, 10.0 + p = UniformPrior(xmin, xmax, ["x"]) + # Check that the log_prob is correct in the support + samples = trace_prior_parent(p, 
[])[0].sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin))) - def test_sequence(self): - raise NotImplementedError + def test_sine(self): + p = SinePrior(["x"]) + # Check that the log_prob is finite + samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Check that the log_prob is correct in the support + x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) + y = jax.vmap(p.base_prior.base_prior.transform)(x) + y = jax.vmap(p.base_prior.transform)(y) + y = jax.vmap(p.transform)(y) + assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.sin(y['x'])/2.0)) + + def test_cosine(self): + p = CosinePrior(["x"]) + # Check that the log_prob is finite + samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Check that the log_prob is correct in the support + x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) + y = jax.vmap(p.base_prior.transform)(x) + y = jax.vmap(p.transform)(y) + assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.cos(y['x'])/2.0)) - def test_factor(self): - raise NotImplementedError + def test_uniform_sphere(self): + p = UniformSpherePrior(["x"]) + # Check that the log_prob is finite + samples = {} + for i in range(3): + samples.update(trace_prior_parent(p)[i].sample(jax.random.PRNGKey(0), 10000)) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) From 142fe5446db1044a3c6385060a963afbfa65cef3 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:19:25 -0400 Subject: [PATCH 049/172] Set constraint on powerlaw prior input --- src/jimgw/prior.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/src/jimgw/prior.py b/src/jimgw/prior.py index c1baca3c..b349d074 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -348,6 +348,8 @@ def __init__( self.xmax = xmax self.xmin = xmin self.alpha = alpha + assert self.xmin < self.xmax, "xmin must be less than xmax" + assert self.xmin > 0.0, "x must be positive" if self.alpha == -1.0: transform = ParetoTransform( (self.parameter_names, self.parameter_names), xmin, xmax From fa6a6e25aa472e0d19233aca66c8daca9e60214e Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:20:23 -0400 Subject: [PATCH 050/172] Added test_power_law --- test/test_prior.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/test/test_prior.py b/test/test_prior.py index 48902b3b..0f9e4566 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -64,3 +64,33 @@ def test_uniform_sphere(self): samples.update(trace_prior_parent(p)[i].sample(jax.random.PRNGKey(0), 10000)) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) + + def test_power_law(self): + from bilby.core.prior.analytical import PowerLaw + def func(alpha): + xmin = 1.0 + xmax = 20.0 + p = PowerLawPrior(xmin, xmax, alpha, ["x"]) + # Check that all the samples are finite + powerlaw_samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) + + # Check that all the log_probs are finite + samples = trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000)['x'] + base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) + assert jnp.all(jnp.isfinite(log_p)) + + # Check that the log_prob is correct in the support + samples = jnp.linspace(xmin, xmax, 1000) + transformed_samples = jax.vmap(p.transform)({'x': samples})['x'] + assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), PowerLaw(alpha, xmin, xmax).ln_prob(transformed_samples)) + + # Test Pareto Transform + func(-1.0) + # Test other values of alpha + positive_alpha = 
[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0] + for alpha_val in positive_alpha: + func(alpha_val) + negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] + for alpha_val in negative_alpha: + func(alpha_val) From a8c06737056cf32ef7347f38af3ceceeabc3818d Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:22:25 -0400 Subject: [PATCH 051/172] Updated ParetoTransform to avoid divide by zero --- src/jimgw/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4c1380cf..48e29a06 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -207,5 +207,5 @@ def __init__( self.xmin = xmin self.xmax = xmax self.transform_func = lambda x: self.xmin * jnp.exp( - x * jnp.log(self.xmax / self.xmin) + x * jnp.log(self.xmax) - x * jnp.log(self.xmin) ) From 6ce5b369ee8cb7eead770412341dda61141b2c3b Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:29:19 -0400 Subject: [PATCH 052/172] Revert update on ParetoTransform --- src/jimgw/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 48e29a06..4c1380cf 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -207,5 +207,5 @@ def __init__( self.xmin = xmin self.xmax = xmax self.transform_func = lambda x: self.xmin * jnp.exp( - x * jnp.log(self.xmax) - x * jnp.log(self.xmin) + x * jnp.log(self.xmax / self.xmin) ) From e9511b9aa64e5dd7851760b49f9dbf0bdfe87061 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:29:49 -0400 Subject: [PATCH 053/172] Updated test_prior.py --- test/test_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prior.py b/test/test_prior.py index 0f9e4566..83392751 100644 --- a/test/test_prior.py +++ 
b/test/test_prior.py @@ -78,7 +78,7 @@ def func(alpha): # Check that all the log_probs are finite samples = trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000)['x'] base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) - assert jnp.all(jnp.isfinite(log_p)) + assert jnp.all(jnp.isfinite(base_log_p)) # Check that the log_prob is correct in the support samples = jnp.linspace(xmin, xmax, 1000) From 54082eaa7707c115e90ee842f5d791a12aa80bfc Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:00:06 -0400 Subject: [PATCH 054/172] Updated test_prior.py --- test/test_prior.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index 83392751..cee29934 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -68,8 +68,8 @@ def test_uniform_sphere(self): def test_power_law(self): from bilby.core.prior.analytical import PowerLaw def func(alpha): - xmin = 1.0 - xmax = 20.0 + xmin = 0.05 + xmax = 10.0 p = PowerLawPrior(xmin, xmax, alpha, ["x"]) # Check that all the samples are finite powerlaw_samples = p.sample(jax.random.PRNGKey(0), 10000) @@ -81,10 +81,15 @@ def func(alpha): assert jnp.all(jnp.isfinite(base_log_p)) # Check that the log_prob is correct in the support - samples = jnp.linspace(xmin, xmax, 1000) + samples = jnp.linspace(-10.0, 10.0, 1000) transformed_samples = jax.vmap(p.transform)({'x': samples})['x'] - assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), PowerLaw(alpha, xmin, xmax).ln_prob(transformed_samples)) - + # cut off the samples that are outside the support + samples = samples[transformed_samples >= xmin] + transformed_samples = transformed_samples[transformed_samples >= xmin] + samples = samples[transformed_samples <= xmax] + transformed_samples = transformed_samples[transformed_samples <= xmax] + assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), PowerLaw(alpha, xmin, xmax).ln_prob(transformed_samples), 
atol=1e-4) + # Test Pareto Transform func(-1.0) # Test other values of alpha @@ -93,4 +98,4 @@ def func(alpha): func(alpha_val) negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] for alpha_val in negative_alpha: - func(alpha_val) + func(alpha_val) \ No newline at end of file From 9ad64a1ca0bf0a2c0ea945b5cae3f02c38c07199 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:05:02 -0400 Subject: [PATCH 055/172] Updated test_prior.py --- test/test_prior.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/test_prior.py b/test/test_prior.py index cee29934..c0106fa9 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -5,6 +5,9 @@ class TestUnivariatePrior: def test_logistic(self): p = LogisticDistribution(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -15,6 +18,9 @@ def test_logistic(self): def test_standard_normal(self): p = StandardNormalDistribution(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -26,6 +32,9 @@ def test_standard_normal(self): def test_uniform(self): xmin, xmax = -10.0, 10.0 p = UniformPrior(xmin, xmax, ["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is correct in the support samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -33,6 +42,9 @@ def test_uniform(self): def test_sine(self): p = SinePrior(["x"]) + # 
Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -46,6 +58,9 @@ def test_sine(self): def test_cosine(self): p = CosinePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -58,6 +73,10 @@ def test_cosine(self): def test_uniform_sphere(self): p = UniformSpherePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) + # Check that the log_prob is finite samples = {} for i in range(3): From ce437c417d5d9279f0fed75dce44c2f179edca62 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:12:32 -0400 Subject: [PATCH 056/172] Updated test_prior.py --- test/test_prior.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index c0106fa9..7cd4d8f8 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -75,8 +75,9 @@ def test_uniform_sphere(self): p = UniformSpherePrior(["x"]) # Check that all the samples are finite samples = p.sample(jax.random.PRNGKey(0), 10000) - assert jnp.all(jnp.isfinite(samples['x'])) - + assert jnp.all(jnp.isfinite(samples['x_mag'])) + assert jnp.all(jnp.isfinite(samples['x_theta'])) + assert jnp.all(jnp.isfinite(samples['x_phi'])) # Check that the log_prob is finite samples = {} for i in range(3): From 18d5184837c2adc9e2db7df0c9d5b071ac406cf6 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:17:26 
-0400 Subject: [PATCH 057/172] Updated test_prior.py --- test/test_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prior.py b/test/test_prior.py index 7cd4d8f8..cdfdbae0 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -96,7 +96,7 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000)['x'] + samples = (trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000))['x'] base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) assert jnp.all(jnp.isfinite(base_log_p)) From 7c05d14df30b4ba8894967da0f1334ba95419ce2 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:27:02 -0400 Subject: [PATCH 058/172] Updated test_prior.py --- test/test_prior.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index cdfdbae0..11206877 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -81,7 +81,7 @@ def test_uniform_sphere(self): # Check that the log_prob is finite samples = {} for i in range(3): - samples.update(trace_prior_parent(p)[i].sample(jax.random.PRNGKey(0), 10000)) + samples.update(trace_prior_parent(p, [])[i].sample(jax.random.PRNGKey(0), 10000)) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) @@ -96,7 +96,7 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = (trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000))['x'] + samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x'] base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) assert jnp.all(jnp.isfinite(base_log_p)) From f647bf0155a2cb8b6a4cfe460cd71c348e344e31 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 29 Jul 2024 13:31:38 -0400 Subject: [PATCH 059/172] Remove unnecessary 
test --- test/test_prior.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index 11206877..d9dbe27a 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -5,9 +5,6 @@ class TestUnivariatePrior: def test_logistic(self): p = LogisticDistribution(["x"]) - # Check that all the samples are finite - samples = p.sample(jax.random.PRNGKey(0), 10000) - assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -18,9 +15,6 @@ def test_logistic(self): def test_standard_normal(self): p = StandardNormalDistribution(["x"]) - # Check that all the samples are finite - samples = p.sample(jax.random.PRNGKey(0), 10000) - assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -118,4 +112,4 @@ def func(alpha): func(alpha_val) negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] for alpha_val in negative_alpha: - func(alpha_val) \ No newline at end of file + func(alpha_val) From 6171ce5d7d00935eedb1bebb0067768b1364ca57 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 29 Jul 2024 13:32:24 -0400 Subject: [PATCH 060/172] Reformat --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index b349d074..2d22c368 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -278,7 +278,7 @@ def __init__(self, parameter_names: list[str]): (([f"{self.parameter_names[0]}"], [f"{self.parameter_names[0]}"])), jnp.pi / 2, ) - ] + ], ) From ab7447be5d5ddae1f11768248ef825975a673bda Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 29 Jul 2024 15:15:43 -0400 Subject: [PATCH 061/172] Change naming --- src/jimgw/prior.py | 142 +++++++++++++++------------------------------ 1 file changed, 46 insertions(+), 96 
deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 2d22c368..3697df2f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -250,11 +250,25 @@ def __init__( super().__init__( LogisticDistribution(self.parameter_names), [ - LogitTransform((self.parameter_names, self.parameter_names)), + LogitTransform( + ( + [f"{self.parameter_names[0]}_base"], + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + ) + ), ScaleTransform( - (self.parameter_names, self.parameter_names), xmax - xmin + ( + ( + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + [f"{self.parameter_names[0]}-({xmin})"], + ), + xmax - xmin, + ), + ), + OffsetTransform( + ([f"{self.parameter_names[0]}-({xmin})"], self.parameter_names), + xmin, ), - OffsetTransform((self.parameter_names, self.parameter_names), xmin), ], ) @@ -275,7 +289,12 @@ def __init__(self, parameter_names: list[str]): CosinePrior([f"{self.parameter_names[0]}"]), [ OffsetTransform( - (([f"{self.parameter_names[0]}"], [f"{self.parameter_names[0]}"])), + ( + ( + [f"{self.parameter_names[0]}-pi/2"], + [f"{self.parameter_names[0]}"], + ) + ), jnp.pi / 2, ) ], @@ -295,10 +314,10 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, [f"sin_{self.parameter_names[0]}"]), + UniformPrior(-1.0, 1.0, [f"sin({self.parameter_names[0]})"]), [ ArcSineTransform( - ([f"sin_{self.parameter_names[0]}"], [self.parameter_names[0]]) + ([f"sin({self.parameter_names[0]})"], [self.parameter_names[0]]) ) ], ) @@ -308,7 +327,7 @@ def __init__(self, parameter_names: list[str]): class UniformSpherePrior(CombinePrior): def __repr__(self): - return f"UniformSphere(parameter_names={self.parameter_names})" + return f"UniformSpherePrior(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): self.parameter_names = parameter_names @@ -334,7 +353,7 
@@ class PowerLawPrior(SequentialTransformPrior): alpha: float def __repr__(self): - return f"PowerLaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" + return f"PowerLawPrior(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" def __init__( self, @@ -352,35 +371,45 @@ def __init__( assert self.xmin > 0.0, "x must be positive" if self.alpha == -1.0: transform = ParetoTransform( - (self.parameter_names, self.parameter_names), xmin, xmax + ([f"{self.parameter_names[0]}_before_transform"], self.parameter_names), + xmin, + xmax, ) else: transform = PowerLawTransform( - (self.parameter_names, self.parameter_names), xmin, xmax, alpha + ([f"{self.parameter_names[0]}_before_transform"], self.parameter_names), + xmin, + xmax, + alpha, ) super().__init__( LogisticDistribution(self.parameter_names), [ - LogitTransform((self.parameter_names, self.parameter_names)), + LogitTransform( + ( + [f"{self.parameter_names[0]}_base"], + [f"{self.parameter_names[0]}_before_transform"], + ) + ), transform, ], ) @jaxtyped(typechecker=typechecker) -class UniformComponenChirpMassPrior(PowerLawPrior): +class UniformInComponentsChirpMassPrior(PowerLawPrior): """ A prior in the range [xmin, xmax) for chirp mass which assumes the - component mass to be uniform. + component masses to be uniformly distributed. 
- p(\cal M) ~ \cal M + p(M_c) ~ M_c """ def __repr__(self): - return f"UniformComponentChirpMass(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" - def __init__(self, xmin: float, xmax: float, parameter_names: list[str]): - super().__init__(xmin, xmax, 1.0, parameter_names) + def __init__(self, xmin: float, xmax: float): + super().__init__(xmin, xmax, 1.0, ["M_c"]) def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: @@ -588,85 +617,6 @@ def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: # return output + jnp.log(jnp.sin(zenith)) -# @jaxtyped(typechecker=typechecker) -# class PowerLaw(Prior): -# """ -# A prior following the power-law with alpha in the range [xmin, xmax). -# p(x) ~ x^{\alpha} -# """ - -# xmin: float = 0.0 -# xmax: float = 1.0 -# alpha: float = 0.0 -# normalization: float = 1.0 - -# def __repr__(self): -# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - -# def __init__( -# self, -# xmin: float, -# xmax: float, -# alpha: Union[Int, float], -# naming: list[str], -# transforms: dict[str, tuple[str, Callable]] = {}, -# **kwargs, -# ): -# super().__init__(naming, transforms) -# if alpha < 0.0: -# assert xmin > 0.0, "With negative alpha, xmin must > 0" -# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" -# self.xmax = xmax -# self.xmin = xmin -# self.alpha = alpha -# if alpha == -1: -# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) -# else: -# self.normalization = (1 + self.alpha) / ( -# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) -# ) - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# """ -# Sample from a power-law distribution. 
- -# Parameters -# ---------- -# rng_key : PRNGKeyArray -# A random key to use for sampling. -# n_samples : int -# The number of samples to draw. - -# Returns -# ------- -# samples : dict -# Samples from the distribution. The keys are the names of the parameters. - -# """ -# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) -# if self.alpha == -1: -# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) -# else: -# samples = ( -# self.xmin ** (1.0 + self.alpha) -# + q_samples -# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) -# ) ** (1.0 / (1.0 + self.alpha)) -# return self.add_name(samples[None]) - -# def log_prob(self, x: dict[str, Float]) -> Float: -# variable = x[self.naming[0]] -# log_in_range = jnp.where( -# (variable >= self.xmax) | (variable <= self.xmin), -# jnp.zeros_like(variable) - jnp.inf, -# jnp.zeros_like(variable), -# ) -# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) -# return log_p + log_in_range - - # @jaxtyped(typechecker=typechecker) # class Exponential(Prior): # """ From 1b670566f24efa86d2f824e21c4d4079533c856a Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:25:02 -0400 Subject: [PATCH 062/172] Removed bilby --- test/test_prior.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index d9dbe27a..22ae827e 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -80,7 +80,13 @@ def test_uniform_sphere(self): assert jnp.all(jnp.isfinite(log_prob)) def test_power_law(self): - from bilby.core.prior.analytical import PowerLaw + def powerlaw_log_pdf(x, alpha, xmin, xmax): + if alpha == -1.0: + normalization = 1./(jnp.log(xmax) - jnp.log(xmin)) + else: + normalization = (1.0 + alpha) / (xmax**(1.0 + alpha) - xmin**(1.0 + alpha)) + return jnp.log(normalization) + alpha * jnp.log(x) + def func(alpha): xmin = 0.05 xmax = 10.0 @@ 
-102,11 +108,13 @@ def func(alpha): transformed_samples = transformed_samples[transformed_samples >= xmin] samples = samples[transformed_samples <= xmax] transformed_samples = transformed_samples[transformed_samples <= xmax] - assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), PowerLaw(alpha, xmin, xmax).ln_prob(transformed_samples), atol=1e-4) + # log pdf of powerlaw + assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) # Test Pareto Transform func(-1.0) # Test other values of alpha + print("Testing PowerLawPrior") positive_alpha = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0] for alpha_val in positive_alpha: func(alpha_val) From 9d59c4ddd75cd2583795c5403aeae1895833305d Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 17:00:48 -0400 Subject: [PATCH 063/172] Move unit test to unit directory And create integration directory --- docs/tutorials/naming_system.md | 0 test/integration/test_GW150914.py | 0 test/{ => unit}/test_detector.py | 0 test/{ => unit}/test_prior.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/tutorials/naming_system.md create mode 100644 test/integration/test_GW150914.py rename test/{ => unit}/test_detector.py (100%) rename test/{ => unit}/test_prior.py (100%) diff --git a/docs/tutorials/naming_system.md b/docs/tutorials/naming_system.md new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_detector.py b/test/unit/test_detector.py similarity index 100% rename from test/test_detector.py rename to test/unit/test_detector.py diff --git a/test/test_prior.py b/test/unit/test_prior.py similarity index 100% rename from test/test_prior.py rename to test/unit/test_prior.py From b3110dfd33d6df85cd69e1fa4b2face55979076c Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 
17:27:17 -0400 Subject: [PATCH 064/172] Update priors naming issue add test_GW150914.py --- .github/workflows/python-package.yml | 6 +- src/jimgw/prior.py | 12 ++- test/integration/test_GW150914.py | 116 +++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 12 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 4806efed..1b45e23c 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -3,11 +3,7 @@ name: Python package -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] +on: push, pull_request jobs: build: diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3697df2f..f224b119 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -248,7 +248,7 @@ def __init__( self.xmax = xmax self.xmin = xmin super().__init__( - LogisticDistribution(self.parameter_names), + LogisticDistribution([f"{self.parameter_names[0]}_base"]), [ LogitTransform( ( @@ -258,12 +258,10 @@ def __init__( ), ScaleTransform( ( - ( - [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], - [f"{self.parameter_names[0]}-({xmin})"], - ), - xmax - xmin, + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + [f"{self.parameter_names[0]}-({xmin})"], ), + xmax - xmin, ), OffsetTransform( ([f"{self.parameter_names[0]}-({xmin})"], self.parameter_names), @@ -383,7 +381,7 @@ def __init__( alpha, ) super().__init__( - LogisticDistribution(self.parameter_names), + LogisticDistribution([f"{self.parameter_names[0]}_base"]), [ LogitTransform( ( diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index e69de29b..816647fa 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -0,0 +1,116 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood 
import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"] + +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior( + 0.125, + 1.0, + parameter_names=["q"], +) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +cos_iota_prior = UniformPrior( + -1.0, + 1.0, + parameter_names=["cos_iota"], +) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +sin_dec_prior = UniformPrior( + -1.0, + 1.0, + parameter_names=["sin_dec"], +) + +prior = CombinePrior( + [ + Mc_prior, + q_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] +) +likelihood = TransientLikelihoodFD( + [H1, L1], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = 
{"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) From 024c5c3fa6bb9dd69ea5f2dd573046a6d849baac Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 17:53:00 -0400 Subject: [PATCH 065/172] base parameter names seem working --- src/jimgw/jim.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 08b82b05..1a694172 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -7,7 +7,7 @@ from jaxtyping import Array, Float, PRNGKeyArray from jimgw.base import LikelihoodBase -from jimgw.prior import Prior +from jimgw.prior import Prior, trace_prior_parent class Jim(object): @@ -26,11 +26,18 @@ def __init__( self, likelihood: LikelihoodBase, prior: Prior, - parameter_names: list[str], + parameter_names: list[str] | None = None, **kwargs, ): self.likelihood = likelihood self.prior = prior + if parameter_names is None: + print("No parameter names provided. 
Will try to trace the prior.") + parents = [] + trace_prior_parent(prior, parents) + parameter_names = [] + for parent in parents: + parameter_names.extend(parent.parameter_names) self.parameter_names = parameter_names seed = kwargs.get("seed", 0) From ac43d54ffb97089513b71e269fcf8bde954d7839 Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 17:55:44 -0400 Subject: [PATCH 066/172] fix cosine naming error --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index f224b119..6a8f10ca 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -284,7 +284,7 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( - CosinePrior([f"{self.parameter_names[0]}"]), + CosinePrior([f"{self.parameter_names[0]}-pi/2"]), [ OffsetTransform( ( From 151c37df64d147bf5b7e5ddfffa9dd29f0e6d221 Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 17:59:20 -0400 Subject: [PATCH 067/172] prior test now all pass --- .github/workflows/python-package.yml | 2 +- test/unit/test_prior.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 1b45e23c..d70c2ae1 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -3,7 +3,7 @@ name: Python package -on: push, pull_request +on: [push, pull_request] jobs: build: diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index 22ae827e..20d71de4 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -96,20 +96,20 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x'] - base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) + samples = 
(trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x_base'] + base_log_p = jax.vmap(p.log_prob, [0])({'x_base':samples}) assert jnp.all(jnp.isfinite(base_log_p)) # Check that the log_prob is correct in the support samples = jnp.linspace(-10.0, 10.0, 1000) - transformed_samples = jax.vmap(p.transform)({'x': samples})['x'] + transformed_samples = jax.vmap(p.transform)({'x_base': samples})['x'] # cut off the samples that are outside the support samples = samples[transformed_samples >= xmin] transformed_samples = transformed_samples[transformed_samples >= xmin] samples = samples[transformed_samples <= xmax] transformed_samples = transformed_samples[transformed_samples <= xmax] # log pdf of powerlaw - assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) + assert jnp.allclose(jax.vmap(p.log_prob)({'x_base':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) # Test Pareto Transform func(-1.0) From 576ce7cd14cda25522c4dda9cfce906b2f4570f1 Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 18:20:29 -0400 Subject: [PATCH 068/172] Fix likelihood sampling issue due to parameter name transformation --- test/integration/test_GW150914.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 816647fa..2650cd1c 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -36,24 +36,25 @@ q_prior = UniformPrior( 0.125, 1.0, - parameter_names=["q"], + parameter_names=["q"], # Need name transformation in likelihood to work ) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +# Current likelihood sampling will fail and give nan because of large number dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, 
parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) cos_iota_prior = UniformPrior( -1.0, 1.0, - parameter_names=["cos_iota"], + parameter_names=["cos_iota"], # Need name transformation in likelihood to work ) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) sin_dec_prior = UniformPrior( -1.0, 1.0, - parameter_names=["sin_dec"], + parameter_names=["sin_dec"], # Need name transformation in likelihood to work ) prior = CombinePrior( From 7625dbe04eab2c436d2ae27c91d2ee3148c98b0d Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 30 Jul 2024 14:51:56 -0400 Subject: [PATCH 069/172] Instead of defining univariate and multivariate, favor bijective versus nonbijective --- src/jimgw/transforms.py | 87 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 9 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4c1380cf..902b5b30 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -10,12 +10,10 @@ class Transform(ABC): """ Base class for transform. 
- - The idea of transform should be used on distribtuion, + The purpose of this class is purely for keeping name """ name_mapping: tuple[list[str], list[str]] - transform_func: Callable[[dict[str, Float]], dict[str, Float]] def __init__( self, @@ -23,6 +21,17 @@ def __init__( ): self.name_mapping = name_mapping + def propagate_name(self, x: list[str]) -> list[str]: + input_set = set(x) + from_set = set(self.name_mapping[0]) + to_set = set(self.name_mapping[1]) + return list(input_set - from_set | to_set) + +class BijectiveTransform(Transform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] + def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: return self.transform(x) @@ -62,13 +71,64 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: The transformed dictionary. """ raise NotImplementedError + + @abstractmethod + def inverse(self, y: dict[str, Float]) -> dict[str, Float]: + """ + Inverse transform the input y to original coordinate x. - def propagate_name(self, x: list[str]) -> list[str]: - input_set = set(x) - from_set = set(self.name_mapping[0]) - to_set = set(self.name_mapping[1]) - return list(input_set - from_set | to_set) + Parameters + ---------- + y : dict[str, Float] + The transformed dictionary. + + Returns + ------- + x : dict[str, Float] + The original dictionary. + """ + raise NotImplementedError + + @abstractmethod + def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + """ + Pull back the input y to original coordinate x and return the log Jacobian determinant. + + Parameters + ---------- + y : dict[str, Float] + The transformed dictionary. + + Returns + ------- + x : dict[str, Float] + The original dictionary. + log_det : Float + The log Jacobian determinant. 
+ """ + raise NotImplementedError + +class NonBijectiveTransform(Transform): + + def __call__(self, x: dict[str, Float]) -> dict[str, Float]: + return self.forward(x) + + @abstractmethod + def forward(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Push forward the input x to transformed coordinate y. + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + """ + raise NotImplementedError class UnivariateTransform(Transform): @@ -94,7 +154,7 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: return x -class ScaleTransform(UnivariateTransform): +class ScaleTransform(BijectiveTransform): scale: Float def __init__( @@ -105,6 +165,15 @@ def __init__( super().__init__(name_mapping) self.scale = scale self.transform_func = lambda x: x * self.scale + self.inverse_transform_func = lambda x: x / self.scale + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + input_params = x.pop(self.name_mapping[0][0]) + assert_rank(input_params, 0) + output_params = self.transform_func(input_params) + jacobian = jnp.log(jnp.abs(self.scale)) + x[self.name_mapping[1][0]] = output_params + return x, jacobian class OffsetTransform(UnivariateTransform): From 9ac722f57dbfed388bd0ab2aa2244f9463a5cc85 Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 30 Jul 2024 16:29:54 -0400 Subject: [PATCH 070/172] Adding some transform, should be working --- src/jimgw/transforms.py | 188 ++++++++++++++++++---------------------- 1 file changed, 84 insertions(+), 104 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 902b5b30..904fe225 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp from chex import assert_rank -from jaxtyping import Float +from jaxtyping import Float, Array class Transform(ABC): @@ -29,13 +29,12 @@ def propagate_name(self, x: list[str]) -> list[str]: 
class BijectiveTransform(Transform): - transform_func: Callable[[dict[str, Float]], dict[str, Float]] - inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] + transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + inverse_transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: return self.transform(x) - @abstractmethod def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Transform the input x to transformed coordinate y and return the log Jacobian determinant. @@ -53,9 +52,12 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. """ - raise NotImplementedError + input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) + output_params = self.transform_func(input_params) + jacobian = jnp.array(jax.jacfwd(self.transform_func)(input_params)) + jax.tree.map(lambda key, value: x.update({key: value}), self.name_mapping[1], output_params) + return x, jnp.log(jnp.linalg.det(jacobian)) - @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ Push forward the input x to transformed coordinate y. @@ -70,9 +72,11 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: y : dict[str, Float] The transformed dictionary. """ - raise NotImplementedError + input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) + output_params = self.transform_func(input_params) + jax.tree.map(lambda key, value: x.update({key: value}), self.name_mapping[1], output_params) + return x - @abstractmethod def inverse(self, y: dict[str, Float]) -> dict[str, Float]: """ Inverse transform the input y to original coordinate x. @@ -87,9 +91,12 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: x : dict[str, Float] The original dictionary. 
""" - raise NotImplementedError + output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) + input_params = self.inverse_transform_func(output_params) + jacobian = jnp.array(jax.jacfwd(self.inverse_transform_func)(output_params)) + jax.tree.map(lambda key, value: y.update({key: value}), self.name_mapping[0], input_params) + return y, jnp.log(jnp.linalg.det(jacobian)) - @abstractmethod def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Pull back the input y to original coordinate x and return the log Jacobian determinant. @@ -106,7 +113,10 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. """ - raise NotImplementedError + output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) + input_params = self.inverse_transform_func(output_params) + jax.tree.map(lambda key, value: y.update({key: value}), self.name_mapping[0], input_params) + return y class NonBijectiveTransform(Transform): @@ -130,30 +140,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ raise NotImplementedError -class UnivariateTransform(Transform): - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - - def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - input_params = x.pop(self.name_mapping[0][0]) - assert_rank(input_params, 0) - output_params = self.transform_func(input_params) - jacobian = jax.jacfwd(self.transform_func)(input_params) - x[self.name_mapping[1][0]] = output_params - return x, jnp.log(jacobian) - - def forward(self, x: dict[str, Float]) -> dict[str, Float]: - input_params = x.pop(self.name_mapping[0][0]) - assert_rank(input_params, 0) - output_params = self.transform_func(input_params) - x[self.name_mapping[1][0]] = output_params - return x - - class ScaleTransform(BijectiveTransform): scale: Float @@ -164,19 +150,10 @@ def __init__( ): 
super().__init__(name_mapping) self.scale = scale - self.transform_func = lambda x: x * self.scale - self.inverse_transform_func = lambda x: x / self.scale - - def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - input_params = x.pop(self.name_mapping[0][0]) - assert_rank(input_params, 0) - output_params = self.transform_func(input_params) - jacobian = jnp.log(jnp.abs(self.scale)) - x[self.name_mapping[1][0]] = output_params - return x, jacobian - + self.transform_func = lambda x: [x[0] * self.scale] + self.inverse_transform_func = lambda x: [x[0] / self.scale] -class OffsetTransform(UnivariateTransform): +class OffsetTransform(BijectiveTransform): offset: Float def __init__( @@ -186,10 +163,11 @@ def __init__( ): super().__init__(name_mapping) self.offset = offset - self.transform_func = lambda x: x + self.offset + self.transform_func = lambda x: [x[0] + self.offset] + self.inverse_transform_func = lambda x: [x[0] - self.offset] -class LogitTransform(UnivariateTransform): +class LogitTransform(BijectiveTransform): """ Logit transform following @@ -205,10 +183,11 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + self.transform_func = lambda x: [1 / (1 + jnp.exp(-x[0]))] + self.inverse_transform_func = lambda x: [jnp.log(x[0] / (1 - x[0]))] -class ArcSineTransform(UnivariateTransform): +class ArcSineTransform(BijectiveTransform): """ ArcSine transformation @@ -225,56 +204,57 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) - - -class PowerLawTransform(UnivariateTransform): - """ - PowerLaw transformation - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. 
- """ - - xmin: Float - xmax: Float - alpha: Float - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - alpha: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.alpha = alpha - self.transform_func = lambda x: ( - self.xmin ** (1.0 + self.alpha) - + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - - -class ParetoTransform(UnivariateTransform): - """ - Pareto transformation: Power law when alpha = -1 - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.transform_func = lambda x: self.xmin * jnp.exp( - x * jnp.log(self.xmax / self.xmin) - ) + self.inverse_transform_func = lambda x: jnp.sin(x) + + +# class PowerLawTransform(UnivariateTransform): +# """ +# PowerLaw transformation +# Parameters +# ---------- +# name_mapping : tuple[list[str], list[str]] +# The name mapping between the input and output dictionary. +# """ + +# xmin: Float +# xmax: Float +# alpha: Float + +# def __init__( +# self, +# name_mapping: tuple[list[str], list[str]], +# xmin: Float, +# xmax: Float, +# alpha: Float, +# ): +# super().__init__(name_mapping) +# self.xmin = xmin +# self.xmax = xmax +# self.alpha = alpha +# self.transform_func = lambda x: ( +# self.xmin ** (1.0 + self.alpha) +# + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) + + +# class ParetoTransform(UnivariateTransform): +# """ +# Pareto transformation: Power law when alpha = -1 +# Parameters +# ---------- +# name_mapping : tuple[list[str], list[str]] +# The name mapping between the input and output dictionary. 
+# """ + +# def __init__( +# self, +# name_mapping: tuple[list[str], list[str]], +# xmin: Float, +# xmax: Float, +# ): +# super().__init__(name_mapping) +# self.xmin = xmin +# self.xmax = xmax +# self.transform_func = lambda x: self.xmin * jnp.exp( +# x * jnp.log(self.xmax / self.xmin) +# ) From 935b31484e06ddbfa3b985e22a8303c0dc50b7a3 Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 30 Jul 2024 16:47:20 -0400 Subject: [PATCH 071/172] Refactor into NtoN and NtoM transform --- src/jimgw/prior.py | 5 ++-- src/jimgw/transforms.py | 52 ++++++++++++++++++++++++++++++----------- 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 6a8f10ca..22985a8a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -8,6 +8,7 @@ from jimgw.transforms import ( Transform, + NtoNTransform, LogitTransform, ScaleTransform, OffsetTransform, @@ -149,7 +150,7 @@ class SequentialTransformPrior(Prior): """ base_prior: Prior - transforms: list[Transform] + transforms: list[NtoNTransform] def __repr__(self): return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" @@ -157,7 +158,7 @@ def __repr__(self): def __init__( self, base_prior: Prior, - transforms: list[Transform], + transforms: list[NtoNTransform], ): self.base_prior = base_prior diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 904fe225..4e62316e 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -27,13 +27,10 @@ def propagate_name(self, x: list[str]) -> list[str]: to_set = set(self.name_mapping[1]) return list(input_set - from_set | to_set) -class BijectiveTransform(Transform): - transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] - inverse_transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] +class NtoNTransform(Transform): - def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - return self.transform(x) + transform_func: 
Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ @@ -55,7 +52,11 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) output_params = self.transform_func(input_params) jacobian = jnp.array(jax.jacfwd(self.transform_func)(input_params)) - jax.tree.map(lambda key, value: x.update({key: value}), self.name_mapping[1], output_params) + jax.tree.map( + lambda key, value: x.update({key: value}), + self.name_mapping[1], + output_params, + ) return x, jnp.log(jnp.linalg.det(jacobian)) def forward(self, x: dict[str, Float]) -> dict[str, Float]: @@ -74,9 +75,21 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) output_params = self.transform_func(input_params) - jax.tree.map(lambda key, value: x.update({key: value}), self.name_mapping[1], output_params) + jax.tree.map( + lambda key, value: x.update({key: value}), + self.name_mapping[1], + output_params, + ) return x - + + +class BijectiveTransform(NtoNTransform): + + inverse_transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + + def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + return self.transform(x) + def inverse(self, y: dict[str, Float]) -> dict[str, Float]: """ Inverse transform the input y to original coordinate x. 
@@ -94,7 +107,11 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) input_params = self.inverse_transform_func(output_params) jacobian = jnp.array(jax.jacfwd(self.inverse_transform_func)(output_params)) - jax.tree.map(lambda key, value: y.update({key: value}), self.name_mapping[0], input_params) + jax.tree.map( + lambda key, value: y.update({key: value}), + self.name_mapping[0], + input_params, + ) return y, jnp.log(jnp.linalg.det(jacobian)) def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: @@ -115,14 +132,21 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) input_params = self.inverse_transform_func(output_params) - jax.tree.map(lambda key, value: y.update({key: value}), self.name_mapping[0], input_params) + jax.tree.map( + lambda key, value: y.update({key: value}), + self.name_mapping[0], + input_params, + ) return y -class NonBijectiveTransform(Transform): - + +class NtoMTransform(Transform): + + transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " m_dim"]] + def __call__(self, x: dict[str, Float]) -> dict[str, Float]: return self.forward(x) - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -140,6 +164,7 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ raise NotImplementedError + class ScaleTransform(BijectiveTransform): scale: Float @@ -153,6 +178,7 @@ def __init__( self.transform_func = lambda x: [x[0] * self.scale] self.inverse_transform_func = lambda x: [x[0] / self.scale] + class OffsetTransform(BijectiveTransform): offset: Float From 685670a3c97b3b0300ce6a7ef6351bd56375c3e1 Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 30 Jul 2024 17:06:02 -0400 Subject: [PATCH 072/172] Fix bugs in ArcSine --- src/jimgw/prior.py | 4 ++-- src/jimgw/transforms.py | 4 ++-- 2 files changed, 
4 insertions(+), 4 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 22985a8a..f291b23a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -13,8 +13,8 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - PowerLawTransform, - ParetoTransform, + # PowerLawTransform, + # ParetoTransform, ) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4e62316e..ca759a07 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -229,8 +229,8 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: jnp.arcsin(x) - self.inverse_transform_func = lambda x: jnp.sin(x) + self.transform_func = lambda x: [jnp.arcsin(x[0])] + self.inverse_transform_func = lambda x: [jnp.sin(x[0])] # class PowerLawTransform(UnivariateTransform): From 39126f5cf51be98d7ac5d53ed3ba3ed4e37c5259 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Tue, 30 Jul 2024 17:41:48 -0400 Subject: [PATCH 073/172] Add inverse mass transform --- src/jimgw/single_event/utils.py | 54 +++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 3f0decce..870c70ae 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -61,7 +61,7 @@ def m1m2_to_Mq(m1: Float, m2: Float): return M_tot, q -def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): +def M_q_to_m1_m2(trans_M_tot: Float, trans_q: Float): """ Transforming the Total mass M and mass ratio q to the primary mass m1 and secondary mass m2. @@ -87,7 +87,7 @@ def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): return m1, m2 -def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: +def Mc_q_to_m1_m2(Mc: Float, q: Float) -> tuple[Float, Float]: """ Transforming the chirp mass Mc and mass ratio q to the primary mass m1 and secondary mass m2. 
@@ -113,6 +113,56 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: return m1, m2 +def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the chirp mass Mc + and mass ratio q. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + Mc : Float + Chirp mass. + q : Float + Mass ratio. + """ + M_tot = m1 + m2 + eta = m1 * m2 / M_tot ** 2 + Mc = M_tot * eta ** (3.0 / 5) + q = m2 / m1 + return Mc, q + + +def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the total mass M + and symmetric mass ratio eta. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M : Float + Total mass. + eta : Float + Symmetric mass ratio. + """ + M = m1 + m2 + eta = m1 * m2 / M ** 2 + return M, eta + + def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: """ Transforming the right ascension ra and declination dec to the polar angle From 46bd0442c6552a8967a0f9be388fa778a385bd4c Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 30 Jul 2024 23:27:24 -0400 Subject: [PATCH 074/172] Update sequential prior class and test_GW15014.py. 
--- src/jimgw/prior.py | 31 +++++++++++++++++-------------- test/integration/test_GW150914.py | 26 +++++++++----------------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index f291b23a..e23d0dfa 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -7,8 +7,7 @@ from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped from jimgw.transforms import ( - Transform, - NtoNTransform, + BijectiveTransform, LogitTransform, ScaleTransform, OffsetTransform, @@ -61,7 +60,7 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: raise NotImplementedError - def log_prob(self, x: dict[str, Array]) -> Float: + def log_prob(self, z: dict[str, Array]) -> Float: raise NotImplementedError @@ -99,7 +98,7 @@ def sample( samples = jnp.log(samples / (1 - samples)) return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Float]) -> Float: + def log_prob(self, z: dict[str, Float]) -> Float: variable = x[self.parameter_names[0]] return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) @@ -139,7 +138,7 @@ def sample( samples = jax.random.normal(rng_key, (n_samples,)) return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Float]) -> Float: + def log_prob(self, z: dict[str, Float]) -> Float: variable = x[self.parameter_names[0]] return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) @@ -147,10 +146,12 @@ def log_prob(self, x: dict[str, Float]) -> Float: class SequentialTransformPrior(Prior): """ Transform a prior distribution by applying a sequence of transforms. 
+ The space before the transform is named as x, + and the space after the transform is named as z """ base_prior: Prior - transforms: list[NtoNTransform] + transforms: list[BijectiveTransform] def __repr__(self): return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" @@ -158,7 +159,7 @@ def __repr__(self): def __init__( self, base_prior: Prior, - transforms: list[NtoNTransform], + transforms: list[BijectiveTransform], ): self.base_prior = base_prior @@ -174,14 +175,16 @@ def sample( output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - def log_prob(self, x: dict[str, Float]) -> Float: + def log_prob(self, z: dict[str, Float]) -> Float: """ - log_prob has to be evaluated in the space of the base_prior. + Evaluating the probability of the transformed variable z. + This is what flowMC should sample from """ - output = self.base_prior.log_prob(x) - for transform in self.transforms: - x, log_jacobian = transform.transform(x) + output = 0 + for transform in reversed(self.transforms): + z, log_jacobian = transform.inverse(z) output -= log_jacobian + output += self.base_prior.log_prob(z) return output def transform(self, x: dict[str, Float]) -> dict[str, Float]: @@ -223,10 +226,10 @@ def sample( output.update(prior.sample(subkey, n_samples)) return output - def log_prob(self, x: dict[str, Float]) -> Float: + def log_prob(self, z: dict[str, Float]) -> Float: output = 0.0 for prior in self.base_prior: - output += prior.log_prob(x) + output += prior.log_prob(z) return output diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 2650cd1c..deb3fb98 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior from jimgw.single_event.detector import H1, L1 
from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD @@ -33,10 +33,10 @@ L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior( +eta_prior = UniformPrior( 0.125, - 1.0, - parameter_names=["q"], # Need name transformation in likelihood to work + 0.25, + parameter_names=["eta"], # Need name transformation in likelihood to work ) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) @@ -44,32 +44,24 @@ dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -cos_iota_prior = UniformPrior( - -1.0, - 1.0, - parameter_names=["cos_iota"], # Need name transformation in likelihood to work -) +iota_prior = CosinePrior(parameter_names=["iota"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) -sin_dec_prior = UniformPrior( - -1.0, - 1.0, - parameter_names=["sin_dec"], # Need name transformation in likelihood to work -) +dec_prior = SinePrior(parameter_names=["dec"]) prior = CombinePrior( [ Mc_prior, - q_prior, + eta_prior, s1z_prior, s2z_prior, dL_prior, t_c_prior, phase_c_prior, - cos_iota_prior, + iota_prior, psi_prior, ra_prior, - sin_dec_prior, + dec_prior, ] ) likelihood = TransientLikelihoodFD( From 8e1441b389560bcb07f6a137705aba6db384d074 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 08:52:16 -0400 Subject: [PATCH 075/172] correct sign errors and minior bugs --- src/jimgw/prior.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e23d0dfa..13ab622f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -99,7 +99,7 @@ 
def sample( return self.add_name(samples[None]) def log_prob(self, z: dict[str, Float]) -> Float: - variable = x[self.parameter_names[0]] + variable = z[self.parameter_names[0]] return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) @@ -139,7 +139,7 @@ def sample( return self.add_name(samples[None]) def log_prob(self, z: dict[str, Float]) -> Float: - variable = x[self.parameter_names[0]] + variable = z[self.parameter_names[0]] return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) @@ -183,7 +183,7 @@ def log_prob(self, z: dict[str, Float]) -> Float: output = 0 for transform in reversed(self.transforms): z, log_jacobian = transform.inverse(z) - output -= log_jacobian + output += log_jacobian output += self.base_prior.log_prob(z) return output From dfdfffad0bc6847ad1519e7884426292b691ba80 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 09:43:00 -0400 Subject: [PATCH 076/172] Add mass transform --- src/jimgw/single_event/utils.py | 58 ++++++++++++++++----------------- src/jimgw/transforms.py | 44 ++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 30 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 870c70ae..4ee5c25e 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -87,14 +87,14 @@ def M_q_to_m1_m2(trans_M_tot: Float, trans_q: Float): return m1, m2 -def Mc_q_to_m1_m2(Mc: Float, q: Float) -> tuple[Float, Float]: +def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]: """ - Transforming the chirp mass Mc and mass ratio q to the primary mass m1 and + Transforming the chirp mass M_c and mass ratio q to the primary mass m1 and secondary mass m2. Parameters ---------- - Mc : Float + M_c : Float Chirp mass. q : Float Mass ratio. @@ -107,36 +107,36 @@ def Mc_q_to_m1_m2(Mc: Float, q: Float) -> tuple[Float, Float]: Secondary mass. 
""" eta = q / (1 + q) ** 2 - M_tot = Mc / eta ** (3.0 / 5) + M_tot = M_c / eta ** (3.0 / 5) m1 = M_tot / (1 + q) m2 = m1 * q return m1, m2 -def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: - """ - Transforming the primary mass m1 and secondary mass m2 to the chirp mass Mc - and mass ratio q. - - Parameters - ---------- - m1 : Float - Primary mass. - m2 : Float - Secondary mass. - - Returns - ------- - Mc : Float - Chirp mass. - q : Float - Mass ratio. - """ - M_tot = m1 + m2 - eta = m1 * m2 / M_tot ** 2 - Mc = M_tot * eta ** (3.0 / 5) - q = m2 / m1 - return Mc, q +def m1_m2_to_M_c_q(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c + and mass ratio q. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + """ + M_tot = m1 + m2 + eta = m1 * m2 / M_tot**2 + M_c = M_tot * eta ** (3.0 / 5) + q = m2 / m1 + return M_c, q def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: @@ -159,7 +159,7 @@ def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: Symmetric mass ratio. """ M = m1 + m2 - eta = m1 * m2 / M ** 2 + eta = m1 * m2 / M**2 return M, eta diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index ca759a07..e377afd2 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -3,9 +3,10 @@ import jax import jax.numpy as jnp -from chex import assert_rank from jaxtyping import Float, Array +from jimgw.single_event.utils import m1_m2_to_Mc_q, Mc_q_to_m1_m2 + class Transform(ABC): """ @@ -233,6 +234,47 @@ def __init__( self.inverse_transform_func = lambda x: [jnp.sin(x[0])] +class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): + """ + Transform component masses to chirp mass and mass ratio. 
+ + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + assert name_mapping == (["m_1", "m_2"], ["M_c", "q"]) + super().__init__(name_mapping) + self.transform_func = lambda x: m1_m2_to_Mc_q(x[0], x[1]) + self.inverse_transform_func = lambda x: Mc_q_to_m1_m2(x[0], x[1]) + + +def inverse(transform: BijectiveTransform) -> BijectiveTransform: + """ + Inverse the transform. + + Parameters + ---------- + transform : BijectiveTransform + The transform to be inverted. + + Returns + ------- + BijectiveTransform + The inverted transform. + """ + return BijectiveTransform( + name_mapping=transform.name_mapping, + transform_func=transform.inverse_transform_func, + inverse_transform_func=transform.transform_func, + ) + + # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From dd03bf64f8522ff0f39800ea644c900360e90115 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Wed, 31 Jul 2024 09:50:39 -0400 Subject: [PATCH 077/172] Fixed powerLaw --- src/jimgw/prior.py | 4 +- src/jimgw/transforms.py | 107 +++++++++++++++++++++------------------- 2 files changed, 58 insertions(+), 53 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 13ab622f..aff1ebf5 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -12,8 +12,8 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - # PowerLawTransform, - # ParetoTransform, + PowerLawTransform, + ParetoTransform, ) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index ca759a07..08787071 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -233,54 +233,59 @@ def __init__( self.inverse_transform_func = lambda x: [jnp.sin(x[0])] -# class PowerLawTransform(UnivariateTransform): -# """ -# PowerLaw transformation -# Parameters -# ---------- -# name_mapping : 
tuple[list[str], list[str]] -# The name mapping between the input and output dictionary. -# """ - -# xmin: Float -# xmax: Float -# alpha: Float - -# def __init__( -# self, -# name_mapping: tuple[list[str], list[str]], -# xmin: Float, -# xmax: Float, -# alpha: Float, -# ): -# super().__init__(name_mapping) -# self.xmin = xmin -# self.xmax = xmax -# self.alpha = alpha -# self.transform_func = lambda x: ( -# self.xmin ** (1.0 + self.alpha) -# + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) -# ) ** (1.0 / (1.0 + self.alpha)) - - -# class ParetoTransform(UnivariateTransform): -# """ -# Pareto transformation: Power law when alpha = -1 -# Parameters -# ---------- -# name_mapping : tuple[list[str], list[str]] -# The name mapping between the input and output dictionary. -# """ - -# def __init__( -# self, -# name_mapping: tuple[list[str], list[str]], -# xmin: Float, -# xmax: Float, -# ): -# super().__init__(name_mapping) -# self.xmin = xmin -# self.xmax = xmax -# self.transform_func = lambda x: self.xmin * jnp.exp( -# x * jnp.log(self.xmax / self.xmin) -# ) +class PowerLawTransform(UnivariateTransform): + """ + PowerLaw transformation + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ """ + + xmin: Float + xmax: Float + alpha: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + alpha: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.alpha = alpha + self.transform_func = lambda x: ( + self.xmin ** (1.0 + self.alpha) + + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) ** (1.0 / (1.0 + self.alpha)) + self.inverse_transform_func = lambda x: ( + (x ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + + +class ParetoTransform(BijectiveTransform): + """ + Pareto transformation: Power law when alpha = -1 + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.transform_func = lambda x: self.xmin * jnp.exp( + x * jnp.log(self.xmax / self.xmin) + ) + self.inverse_transform_func = lambda x: (jnp.log(x / self.xmin) / jnp.log(self.xmax / self.xmin)) From 735415e8f1895d0aa3bcefd2a3f94f1c951f207e Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Wed, 31 Jul 2024 09:57:46 -0400 Subject: [PATCH 078/172] Fixed powerLaw --- src/jimgw/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 08787071..7a016704 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -233,7 +233,7 @@ def __init__( self.inverse_transform_func = lambda x: [jnp.sin(x[0])] -class PowerLawTransform(UnivariateTransform): +class PowerLawTransform(BijectiveTransform): """ PowerLaw transformation Parameters From cec7447cbbb482a026612ce2bc77316395374247 Mon Sep 17 00:00:00 2001 From: xuyuon 
<116078673+xuyuon@users.noreply.github.com> Date: Wed, 31 Jul 2024 10:12:03 -0400 Subject: [PATCH 079/172] Fixed powerLaw --- src/jimgw/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 7a016704..e1aa50ba 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -259,10 +259,10 @@ def __init__( self.alpha = alpha self.transform_func = lambda x: ( self.xmin ** (1.0 + self.alpha) - + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + + x[0] * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) ) ** (1.0 / (1.0 + self.alpha)) self.inverse_transform_func = lambda x: ( - (x ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + (x[0] ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) ) @@ -286,6 +286,6 @@ def __init__( self.xmin = xmin self.xmax = xmax self.transform_func = lambda x: self.xmin * jnp.exp( - x * jnp.log(self.xmax / self.xmin) + x[0] * jnp.log(self.xmax / self.xmin) ) - self.inverse_transform_func = lambda x: (jnp.log(x / self.xmin) / jnp.log(self.xmax / self.xmin)) + self.inverse_transform_func = lambda x: (jnp.log(x[0] / self.xmin) / jnp.log(self.xmax / self.xmin)) From 74e553d82510a43552d13438aa227edd9dba6bb4 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Wed, 31 Jul 2024 10:30:28 -0400 Subject: [PATCH 080/172] Fixed powerLaw --- src/jimgw/transforms.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index e1aa50ba..3071f964 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -257,14 +257,14 @@ def __init__( self.xmin = xmin self.xmax = xmax self.alpha = alpha - self.transform_func = lambda x: ( + self.transform_func = lambda x: [( self.xmin ** (1.0 + self.alpha) + x[0] * (self.xmax ** (1.0 + 
self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - self.inverse_transform_func = lambda x: ( + ) ** (1.0 / (1.0 + self.alpha))] + self.inverse_transform_func = lambda x: [( (x[0] ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) + )] class ParetoTransform(BijectiveTransform): @@ -285,7 +285,7 @@ def __init__( super().__init__(name_mapping) self.xmin = xmin self.xmax = xmax - self.transform_func = lambda x: self.xmin * jnp.exp( + self.transform_func = lambda x: [self.xmin * jnp.exp( x[0] * jnp.log(self.xmax / self.xmin) - ) - self.inverse_transform_func = lambda x: (jnp.log(x[0] / self.xmin) / jnp.log(self.xmax / self.xmin)) + )] + self.inverse_transform_func = lambda x: [(jnp.log(x[0] / self.xmin) / jnp.log(self.xmax / self.xmin))] From 9d87e58fa19fc4169ab8f4184e3c1738e817aa6f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 11:08:53 -0400 Subject: [PATCH 081/172] Add simplex transform --- src/jimgw/single_event/utils.py | 4 ++-- src/jimgw/transforms.py | 31 ++++++++++++++++--------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 4ee5c25e..3d07bb57 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -37,7 +37,7 @@ def inner_product( return 4.0 * jnp.real(trapezoid(integrand, dx=df)) -def m1m2_to_Mq(m1: Float, m2: Float): +def m1_m2_to_M_q(m1: Float, m2: Float): """ Transforming the primary mass m1 and secondary mass m2 to the Total mass M and mass ratio q. @@ -113,7 +113,7 @@ def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]: return m1, m2 -def m1_m2_to_M_c_q(m1: Float, m2: Float) -> tuple[Float, Float]: +def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: """ Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c and mass ratio q. 
diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index e377afd2..be181fcd 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -196,7 +196,7 @@ def __init__( class LogitTransform(BijectiveTransform): """ - Logit transform following + Logit transformation Parameters ---------- @@ -254,25 +254,26 @@ def __init__( self.inverse_transform_func = lambda x: Mc_q_to_m1_m2(x[0], x[1]) -def inverse(transform: BijectiveTransform) -> BijectiveTransform: +class RectangleToTriangleTransform(BijectiveTransform): """ - Inverse the transform. + Transform a rectangle grid with bounds [0, 1] x [0, 1] to a triangle grid with vertices (0, 0), (1, 0), (0, 1), while preserving the area density. Parameters ---------- - transform : BijectiveTransform - The transform to be inverted. - - Returns - ------- - BijectiveTransform - The inverted transform. + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. """ - return BijectiveTransform( - name_mapping=transform.name_mapping, - transform_func=transform.inverse_transform_func, - inverse_transform_func=transform.transform_func, - ) + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: [ + 1 - jnp.sqrt(1 - x[0]), + x[1] * jnp.sqrt(1 - x[0]), + ] + self.inverse_transform_func = lambda x: [1 - (1 - x[0]) ** 2, x[1] / (1 - x[0])] # class PowerLawTransform(UnivariateTransform): From 52fa0eb28308aae155c566c677317b94645a83ea Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 11:10:23 -0400 Subject: [PATCH 082/172] Change transform to take dictionary as input for transform_func --- src/jimgw/transforms.py | 119 +++++++++++++++++++++++++++------------- 1 file changed, 80 insertions(+), 39 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index ca759a07..13d66173 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -30,7 +30,7 @@ 
def propagate_name(self, x: list[str]) -> list[str]: class NtoNTransform(Transform): - transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + transform_func: Callable[[dict[str, Float]], dict[str, Float]] def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ @@ -49,15 +49,22 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. """ - input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) - output_params = self.transform_func(input_params) - jacobian = jnp.array(jax.jacfwd(self.transform_func)(input_params)) + x_copy = x.copy() + output_params = self.transform_func(x_copy) + jacobian = jax.jacfwd(self.transform_func)(x_copy) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log( + jnp.linalg.det(jacobian.reshape(int(jnp.sqrt(jacobian.size)), -1)) + ) jax.tree.map( - lambda key, value: x.update({key: value}), - self.name_mapping[1], - output_params, + lambda key: x_copy.pop(key), + self.name_mapping[0], ) - return x, jnp.log(jnp.linalg.det(jacobian)) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy, jacobian def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -73,19 +80,22 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: y : dict[str, Float] The transformed dictionary. 
""" - input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) - output_params = self.transform_func(input_params) + x_copy = x.copy() + output_params = self.transform_func(x_copy) jax.tree.map( - lambda key, value: x.update({key: value}), - self.name_mapping[1], - output_params, + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), ) - return x + return x_copy class BijectiveTransform(NtoNTransform): - inverse_transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: return self.transform(x) @@ -104,15 +114,22 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: x : dict[str, Float] The original dictionary. """ - output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) - input_params = self.inverse_transform_func(output_params) - jacobian = jnp.array(jax.jacfwd(self.inverse_transform_func)(output_params)) + y_copy = y.copy() + output_params = self.inverse_transform_func(y_copy) + jacobian = jax.jacfwd(self.inverse_transform_func)(y_copy) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log( + jnp.linalg.det(jacobian.reshape(int(jnp.sqrt(jacobian.size)), -1)) + ) jax.tree.map( - lambda key, value: y.update({key: value}), - self.name_mapping[0], - input_params, + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), ) - return y, jnp.log(jnp.linalg.det(jacobian)) + return y_copy, jacobian def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ @@ -130,14 +147,17 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. 
""" - output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) - input_params = self.inverse_transform_func(output_params) + y_copy = y.copy() + output_params = self.inverse_transform_func(y_copy) jax.tree.map( - lambda key, value: y.update({key: value}), - self.name_mapping[0], - input_params, + lambda key: y_copy.pop(key), + self.name_mapping[1], ) - return y + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy class NtoMTransform(Transform): @@ -175,8 +195,14 @@ def __init__( ): super().__init__(name_mapping) self.scale = scale - self.transform_func = lambda x: [x[0] * self.scale] - self.inverse_transform_func = lambda x: [x[0] / self.scale] + self.transform_func = lambda x: { + name_mapping[1][i]: x[name_mapping[0][i]] * self.scale + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: x[name_mapping[1][i]] / self.scale + for i in range(len(name_mapping[1])) + } class OffsetTransform(BijectiveTransform): @@ -189,9 +215,14 @@ def __init__( ): super().__init__(name_mapping) self.offset = offset - self.transform_func = lambda x: [x[0] + self.offset] - self.inverse_transform_func = lambda x: [x[0] - self.offset] - + self.transform_func = lambda x: { + name_mapping[1][i]: x[name_mapping[0][i]] + self.offset + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: x[name_mapping[1][i]] - self.offset + for i in range(len(name_mapping[1])) + } class LogitTransform(BijectiveTransform): """ @@ -209,9 +240,14 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: [1 / (1 + jnp.exp(-x[0]))] - self.inverse_transform_func = lambda x: [jnp.log(x[0] / (1 - x[0]))] - + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.log(x[name_mapping[0][i]] / (1 - x[name_mapping[0][i]])) + for i in 
range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: 1 / (1 + jnp.exp(-x[name_mapping[1][i]])) + for i in range(len(name_mapping[1])) + } class ArcSineTransform(BijectiveTransform): """ @@ -229,9 +265,14 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: [jnp.arcsin(x[0])] - self.inverse_transform_func = lambda x: [jnp.sin(x[0])] - + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.arcsin(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.sin(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } # class PowerLawTransform(UnivariateTransform): # """ From 2aef04bc78147221b6d1abd65316447f49a00e50 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:11:23 -0400 Subject: [PATCH 083/172] Updated test_prior.py --- test/unit/test_prior.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index 20d71de4..1cbed508 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -96,20 +96,16 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x_base'] - base_log_p = jax.vmap(p.log_prob, [0])({'x_base':samples}) - assert jnp.all(jnp.isfinite(base_log_p)) + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_p = jax.vmap(p.log_prob, [0])(samples) + assert jnp.all(jnp.isfinite(log_p)) # Check that the log_prob is correct in the support - samples = jnp.linspace(-10.0, 10.0, 1000) - transformed_samples = jax.vmap(p.transform)({'x_base': samples})['x'] - # cut off the samples that are outside the support - samples = samples[transformed_samples >= xmin] - 
transformed_samples = transformed_samples[transformed_samples >= xmin] - samples = samples[transformed_samples <= xmax] - transformed_samples = transformed_samples[transformed_samples <= xmax] + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + standard_log_prob = powerlaw_log_pdf(samples['x'], alpha, xmin, xmax) # log pdf of powerlaw - assert jnp.allclose(jax.vmap(p.log_prob)({'x_base':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) + assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4) # Test Pareto Transform func(-1.0) From 1e389bb1e19332576ece2d2911144a84e9dc16af Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 11:46:05 -0400 Subject: [PATCH 084/172] Add UniformComponentMassPrior --- src/jimgw/prior.py | 51 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 13ab622f..7880648d 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -12,8 +12,7 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - # PowerLawTransform, - # ParetoTransform, + RectangleToTriangleTransform, ) @@ -399,7 +398,53 @@ def __init__( @jaxtyped(typechecker=typechecker) -class UniformInComponentsChirpMassPrior(PowerLawPrior): +class UniformComponentMassPrior(SequentialTransformPrior): + """ + A prior in the range [xmin, xmax) for component masses which assumes the + component masses to be uniformly distributed. 
+ """ + + xmin: float + xmax: float + + def __repr__(self): + return f"UniformComponentMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + + def __init__(self, xmin: float, xmax: float, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 2, "UniformComponentMassPrior needs to be 2D distributions" + self.xmax = xmax + self.xmin = xmin + super().__init__( + CombinePrior( + [ + UniformPrior(xmin, xmax, ["x_1"]), + UniformPrior(xmin, xmax, ["x_2"]), + ] + ), + [ + ScaleTransform( + ( + ["x_1", "x_2"], + [f"x_1/({xmax-xmin})", f"x_2/({xmax-xmin})"], + ), + 1 / (xmax - xmin), + ), + RectangleToTriangleTransform( + ( + [ + f"{self.parameter_names[0]}/({xmax-xmin})", + f"{self.parameter_names[1]}/({xmax-xmin})", + ], + [self.parameter_names[0], self.parameter_names[1]], + ) + ), + ], + ) + + +@jaxtyped(typechecker=typechecker) +class UniformComponentChirpMassPrior(PowerLawPrior): """ A prior in the range [xmin, xmax) for chirp mass which assumes the component masses to be uniformly distributed. From 0252392aacf4fa53e18a7d5169e57dac6bcd15bb Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 12:26:25 -0400 Subject: [PATCH 085/172] prior system should work with dictionary function now --- src/jimgw/jim.py | 10 +++------- test/integration/test_GW150914.py | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 1a694172..81d56836 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -32,12 +32,8 @@ def __init__( self.likelihood = likelihood self.prior = prior if parameter_names is None: - print("No parameter names provided. Will try to trace the prior.") - parents = [] - trace_prior_parent(prior, parents) - parameter_names = [] - for parent in parents: - parameter_names.extend(parent.parameter_names) + print("No parameter names provided. 
Using prior names.") + parameter_names = prior.parameter_names self.parameter_names = parameter_names seed = kwargs.get("seed", 0) @@ -79,7 +75,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) - prior = self.prior.log_prob(named_params) + prior = self.prior.log_prob(named_params) return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index deb3fb98..4028164f 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -106,4 +106,4 @@ strategies=[Adam_optimizer, "default"], ) -jim.sample(jax.random.PRNGKey(42)) +# jim.sample(jax.random.PRNGKey(42)) From cb75bb9878036038e0dc8733bf51bddaa5a57f8a Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 12:31:03 -0400 Subject: [PATCH 086/172] update transform --- src/jimgw/transforms.py | 22 ++++++++++++++-------- test/integration/test_GW150914.py | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 13d66173..bfabc2ad 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -32,6 +32,10 @@ class NtoNTransform(Transform): transform_func: Callable[[dict[str, Float]], dict[str, Float]] + @property + def n_dim(self) -> int: + return len(self.name_mapping[0]) + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Transform the input x to transformed coordinate y and return the log Jacobian determinant. @@ -50,11 +54,12 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. 
""" x_copy = x.copy() - output_params = self.transform_func(x_copy) - jacobian = jax.jacfwd(self.transform_func)(x_copy) + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(int(jnp.sqrt(jacobian.size)), -1)) + jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) ) jax.tree.map( lambda key: x_copy.pop(key), @@ -115,11 +120,12 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: The original dictionary. """ y_copy = y.copy() - output_params = self.inverse_transform_func(y_copy) - jacobian = jax.jacfwd(self.inverse_transform_func)(y_copy) + transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) + output_params = self.inverse_transform_func(transform_params) + jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(int(jnp.sqrt(jacobian.size)), -1)) + jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) ) jax.tree.map( lambda key: y_copy.pop(key), @@ -241,11 +247,11 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: { - name_mapping[1][i]: jnp.log(x[name_mapping[0][i]] / (1 - x[name_mapping[0][i]])) + name_mapping[1][i]: 1 / (1 + jnp.exp(-x[name_mapping[0][i]])) for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: 1 / (1 + jnp.exp(-x[name_mapping[1][i]])) + name_mapping[0][i]: jnp.log(x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]])) for i in range(len(name_mapping[1])) } diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 4028164f..deb3fb98 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -106,4 +106,4 @@ 
strategies=[Adam_optimizer, "default"], ) -# jim.sample(jax.random.PRNGKey(42)) +jim.sample(jax.random.PRNGKey(42)) From 561e628785e290ed4d07f49a1b3dd7010140457c Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 13:16:52 -0400 Subject: [PATCH 087/172] Move subclassing structure --- src/jimgw/jim.py | 29 ++++++++++---- src/jimgw/transforms.py | 86 +++++++++++++++-------------------------- 2 files changed, 54 insertions(+), 61 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 81d56836..08878293 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -8,6 +8,7 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior, trace_prior_parent +from jimgw.transforms import BijectiveTransform, NtoMTransform class Jim(object): @@ -19,22 +20,30 @@ class Jim(object): prior: Prior # Name of parameters to sample from - parameter_names: list[str] + sample_transforms: list[BijectiveTransform] + likelihood_transforms: list[NtoMTransform] sampler: Sampler def __init__( self, likelihood: LikelihoodBase, prior: Prior, - parameter_names: list[str] | None = None, + sample_transforms: list[BijectiveTransform] = [], + likelihood_transforms: list[NtoMTransform] = [], **kwargs, ): self.likelihood = likelihood self.prior = prior - if parameter_names is None: - print("No parameter names provided. Using prior names.") - parameter_names = prior.parameter_names - self.parameter_names = parameter_names + + self.sample_transforms = sample_transforms + self.likelihood_transforms = likelihood_transforms + + if len(sample_transforms) == 0: + print("No sample transforms provided. Using prior parameters as sampling parameters") + + if len(likelihood_transforms) == 0: + print("No likelihood transforms provided. 
Using prior parameters as likelihood parameters") + seed = kwargs.get("seed", 0) @@ -75,7 +84,13 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) - prior = self.prior.log_prob(named_params) + transform_jacobian = 0.0 + for transform in self.sample_transforms: + named_params, jacobian = transform.inverse(named_params) + transform_jacobian += jacobian + prior = self.prior.log_prob(named_params) + transform_jacobian + for transform in self.likelihood_transforms: + named_params = transform.forward(named_params) return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index bfabc2ad..fec18a79 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -28,18 +28,13 @@ def propagate_name(self, x: list[str]) -> list[str]: return list(input_set - from_set | to_set) -class NtoNTransform(Transform): +class NtoMTransform(Transform): transform_func: Callable[[dict[str, Float]], dict[str, Float]] - @property - def n_dim(self) -> int: - return len(self.name_mapping[0]) - - def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ - Transform the input x to transformed coordinate y and return the log Jacobian determinant. - This only works if the transform is a N -> N transform. + Push forward the input x to transformed coordinate y. Parameters ---------- @@ -50,17 +45,9 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: ------- y : dict[str, Float] The transformed dictionary. - log_det : Float - The log Jacobian determinant. 
""" x_copy = x.copy() - transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) - output_params = self.transform_func(transform_params) - jacobian = jax.jacfwd(self.transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) - ) + output_params = self.transform_func(x_copy) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -69,11 +56,21 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: lambda key: x_copy.update({key: output_params[key]}), list(output_params.keys()), ) - return x_copy, jacobian + return x_copy - def forward(self, x: dict[str, Float]) -> dict[str, Float]: + +class NtoNTransform(NtoMTransform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + @property + def n_dim(self) -> int: + return len(self.name_mapping[0]) + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ - Push forward the input x to transformed coordinate y. + Transform the input x to transformed coordinate y and return the log Jacobian determinant. + This only works if the transform is a N -> N transform. Parameters ---------- @@ -84,9 +81,15 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: ------- y : dict[str, Float] The transformed dictionary. + log_det : Float + The log Jacobian determinant. 
""" x_copy = x.copy() - output_params = self.transform_func(x_copy) + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -95,16 +98,13 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: lambda key: x_copy.update({key: output_params[key]}), list(output_params.keys()), ) - return x_copy + return x_copy, jacobian class BijectiveTransform(NtoNTransform): inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] - def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - return self.transform(x) - def inverse(self, y: dict[str, Float]) -> dict[str, Float]: """ Inverse transform the input y to original coordinate x. @@ -124,9 +124,7 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) - ) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], @@ -166,31 +164,6 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: return y_copy -class NtoMTransform(Transform): - - transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " m_dim"]] - - def __call__(self, x: dict[str, Float]) -> dict[str, Float]: - return self.forward(x) - - @abstractmethod - def forward(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Push forward the input x to transformed coordinate y. 
- - Parameters - ---------- - x : dict[str, Float] - The input dictionary. - - Returns - ------- - y : dict[str, Float] - The transformed dictionary. - """ - raise NotImplementedError - - class ScaleTransform(BijectiveTransform): scale: Float @@ -230,6 +203,7 @@ def __init__( for i in range(len(name_mapping[1])) } + class LogitTransform(BijectiveTransform): """ Logit transform following @@ -251,10 +225,13 @@ def __init__( for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.log(x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]])) + name_mapping[0][i]: jnp.log( + x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]]) + ) for i in range(len(name_mapping[1])) } + class ArcSineTransform(BijectiveTransform): """ ArcSine transformation @@ -280,6 +257,7 @@ def __init__( for i in range(len(name_mapping[1])) } + # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 1aaf5ea91a6774ecc89314af3b953ac34c75539f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 13:59:44 -0400 Subject: [PATCH 088/172] Remove transformation --- src/jimgw/jim.py | 4 ++-- src/jimgw/transforms.py | 37 ++++++++----------------------------- 2 files changed, 10 insertions(+), 31 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 81d56836..03bcbc39 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -7,7 +7,7 @@ from jaxtyping import Array, Float, PRNGKeyArray from jimgw.base import LikelihoodBase -from jimgw.prior import Prior, trace_prior_parent +from jimgw.prior import Prior class Jim(object): @@ -75,7 +75,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) - prior = self.prior.log_prob(named_params) + prior = self.prior.log_prob(named_params) return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = 
jnp.array([])): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4a84d82f..7d1b5f3a 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -59,9 +59,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) - ) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -125,9 +123,7 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) - ) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], @@ -231,6 +227,7 @@ def __init__( for i in range(len(name_mapping[1])) } + class LogitTransform(BijectiveTransform): """ Logit transformation @@ -252,10 +249,13 @@ def __init__( for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.log(x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]])) + name_mapping[0][i]: jnp.log( + x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]]) + ) for i in range(len(name_mapping[1])) } + class ArcSineTransform(BijectiveTransform): """ ArcSine transformation @@ -281,6 +281,7 @@ def __init__( for i in range(len(name_mapping[1])) } + class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): """ Transform component masses to chirp mass and mass ratio. 
@@ -301,28 +302,6 @@ def __init__( self.inverse_transform_func = lambda x: Mc_q_to_m1_m2(x[0], x[1]) -class RectangleToTriangleTransform(BijectiveTransform): - """ - Transform a rectangle grid with bounds [0, 1] x [0, 1] to a triangle grid with vertices (0, 0), (1, 0), (0, 1), while preserving the area density. - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - self.transform_func = lambda x: [ - 1 - jnp.sqrt(1 - x[0]), - x[1] * jnp.sqrt(1 - x[0]), - ] - self.inverse_transform_func = lambda x: [1 - (1 - x[0]) ** 2, x[1] / (1 - x[0])] - - # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 5bfdda13161e954861774881701f49855ff7a267 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 14:01:02 -0400 Subject: [PATCH 089/172] Remove prior --- src/jimgw/prior.py | 47 ---------------------------------------------- 1 file changed, 47 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 7880648d..031a4133 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -12,7 +12,6 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - RectangleToTriangleTransform, ) @@ -397,52 +396,6 @@ def __init__( ) -@jaxtyped(typechecker=typechecker) -class UniformComponentMassPrior(SequentialTransformPrior): - """ - A prior in the range [xmin, xmax) for component masses which assumes the - component masses to be uniformly distributed. 
- """ - - xmin: float - xmax: float - - def __repr__(self): - return f"UniformComponentMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" - - def __init__(self, xmin: float, xmax: float, parameter_names: list[str]): - self.parameter_names = parameter_names - assert self.n_dim == 2, "UniformComponentMassPrior needs to be 2D distributions" - self.xmax = xmax - self.xmin = xmin - super().__init__( - CombinePrior( - [ - UniformPrior(xmin, xmax, ["x_1"]), - UniformPrior(xmin, xmax, ["x_2"]), - ] - ), - [ - ScaleTransform( - ( - ["x_1", "x_2"], - [f"x_1/({xmax-xmin})", f"x_2/({xmax-xmin})"], - ), - 1 / (xmax - xmin), - ), - RectangleToTriangleTransform( - ( - [ - f"{self.parameter_names[0]}/({xmax-xmin})", - f"{self.parameter_names[1]}/({xmax-xmin})", - ], - [self.parameter_names[0], self.parameter_names[1]], - ) - ), - ], - ) - - @jaxtyped(typechecker=typechecker) class UniformComponentChirpMassPrior(PowerLawPrior): """ From 77a6ad1785bb46a5d5f9163d9741e79d83ebacb9 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 14:03:04 -0400 Subject: [PATCH 090/172] Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-class' into transform --- src/jimgw/jim.py | 32 ++++++++++--- src/jimgw/transforms.py | 101 +++++++++++----------------------------- 2 files changed, 53 insertions(+), 80 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 03bcbc39..3186c2c4 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -8,6 +8,7 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior +from jimgw.transforms import BijectiveTransform, NtoMTransform class Jim(object): @@ -19,22 +20,33 @@ class Jim(object): prior: Prior # Name of parameters to sample from - parameter_names: list[str] + sample_transforms: list[BijectiveTransform] + likelihood_transforms: list[NtoMTransform] sampler: Sampler def __init__( self, likelihood: LikelihoodBase, prior: Prior, - parameter_names: list[str] | None = None, + 
sample_transforms: list[BijectiveTransform] = [], + likelihood_transforms: list[NtoMTransform] = [], **kwargs, ): self.likelihood = likelihood self.prior = prior - if parameter_names is None: - print("No parameter names provided. Using prior names.") - parameter_names = prior.parameter_names - self.parameter_names = parameter_names + + self.sample_transforms = sample_transforms + self.likelihood_transforms = likelihood_transforms + + if len(sample_transforms) == 0: + print( + "No sample transforms provided. Using prior parameters as sampling parameters" + ) + + if len(likelihood_transforms) == 0: + print( + "No likelihood transforms provided. Using prior parameters as likelihood parameters" + ) seed = kwargs.get("seed", 0) @@ -75,7 +87,13 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) - prior = self.prior.log_prob(named_params) + transform_jacobian = 0.0 + for transform in self.sample_transforms: + named_params, jacobian = transform.inverse(named_params) + transform_jacobian += jacobian + prior = self.prior.log_prob(named_params) + transform_jacobian + for transform in self.likelihood_transforms: + named_params = transform.forward(named_params) return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 7d1b5f3a..8de02f40 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,11 +1,9 @@ -from abc import ABC, abstractmethod +from abc import ABC from typing import Callable import jax import jax.numpy as jnp -from jaxtyping import Float, Array - -from jimgw.single_event.utils import m1_m2_to_Mc_q, Mc_q_to_m1_m2 +from jaxtyping import Float class Transform(ABC): @@ -29,18 +27,13 @@ def propagate_name(self, x: list[str]) -> list[str]: return list(input_set - from_set | to_set) -class 
NtoNTransform(Transform): +class NtoMTransform(Transform): transform_func: Callable[[dict[str, Float]], dict[str, Float]] - @property - def n_dim(self) -> int: - return len(self.name_mapping[0]) - - def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ - Transform the input x to transformed coordinate y and return the log Jacobian determinant. - This only works if the transform is a N -> N transform. + Push forward the input x to transformed coordinate y. Parameters ---------- @@ -51,15 +44,9 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: ------- y : dict[str, Float] The transformed dictionary. - log_det : Float - The log Jacobian determinant. """ x_copy = x.copy() - transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) - output_params = self.transform_func(transform_params) - jacobian = jax.jacfwd(self.transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + output_params = self.transform_func(x_copy) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -68,11 +55,21 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: lambda key: x_copy.update({key: output_params[key]}), list(output_params.keys()), ) - return x_copy, jacobian + return x_copy - def forward(self, x: dict[str, Float]) -> dict[str, Float]: + +class NtoNTransform(NtoMTransform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + @property + def n_dim(self) -> int: + return len(self.name_mapping[0]) + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ - Push forward the input x to transformed coordinate y. + Transform the input x to transformed coordinate y and return the log Jacobian determinant. + This only works if the transform is a N -> N transform. 
Parameters ---------- @@ -83,9 +80,15 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: ------- y : dict[str, Float] The transformed dictionary. + log_det : Float + The log Jacobian determinant. """ x_copy = x.copy() - output_params = self.transform_func(x_copy) + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -94,16 +97,13 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: lambda key: x_copy.update({key: output_params[key]}), list(output_params.keys()), ) - return x_copy + return x_copy, jacobian class BijectiveTransform(NtoNTransform): inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] - def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - return self.transform(x) - def inverse(self, y: dict[str, Float]) -> dict[str, Float]: """ Inverse transform the input y to original coordinate x. @@ -163,31 +163,6 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: return y_copy -class NtoMTransform(Transform): - - transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " m_dim"]] - - def __call__(self, x: dict[str, Float]) -> dict[str, Float]: - return self.forward(x) - - @abstractmethod - def forward(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Push forward the input x to transformed coordinate y. - - Parameters - ---------- - x : dict[str, Float] - The input dictionary. - - Returns - ------- - y : dict[str, Float] - The transformed dictionary. 
- """ - raise NotImplementedError - - class ScaleTransform(BijectiveTransform): scale: Float @@ -230,7 +205,7 @@ def __init__( class LogitTransform(BijectiveTransform): """ - Logit transformation + Logit transform following Parameters ---------- @@ -282,26 +257,6 @@ def __init__( } -class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): - """ - Transform component masses to chirp mass and mass ratio. - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - assert name_mapping == (["m_1", "m_2"], ["M_c", "q"]) - super().__init__(name_mapping) - self.transform_func = lambda x: m1_m2_to_Mc_q(x[0], x[1]) - self.inverse_transform_func = lambda x: Mc_q_to_m1_m2(x[0], x[1]) - - # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 2401df8a8acf989305c158051b887cc5620b3218 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 16:02:55 -0400 Subject: [PATCH 091/172] Add mass transform --- src/jimgw/single_event/utils.py | 40 +++++++++++++++++++ src/jimgw/transforms.py | 70 +++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 3d07bb57..aaf4b947 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -163,6 +163,46 @@ def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: return M, eta +def Mc_q_to_eta(M_c: Float, q: Float) -> Float: + """ + Transforming the chirp mass M_c and mass ratio q to the symmetric mass ratio eta. + + Parameters + ---------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + + Returns + ------- + eta : Float + Symmetric mass ratio. + """ + eta = q / (1 + q) ** 2 + return eta + + +def eta_to_q(eta: Float) -> Float: + """ + Transforming the symmetric mass ratio eta to the mass ratio q. 
+ + Copied and modified from bilby/gw/conversion.py + + Parameters + ---------- + eta : Float + Symmetric mass ratio. + + Returns + ------- + q : Float + Mass ratio. + """ + temp = 1 / eta / 2 - 1 + return temp - (temp**2 - 1) ** 0.5 + + def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: """ Transforming the right ascension ra and declination dec to the polar angle diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 8de02f40..2e80e3c5 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -5,6 +5,8 @@ import jax.numpy as jnp from jaxtyping import Float +from jimgw.single_event.utils import Mc_q_to_m1_m2, m1_m2_to_Mc_q, Mc_q_to_eta, eta_to_q + class Transform(ABC): """ @@ -257,6 +259,74 @@ def __init__( } +class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to component masses + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + Mc = x[name_mapping[0][0]] + q = x[name_mapping[0][1]] + m1, m2 = Mc_q_to_m1_m2(Mc, q) + return {name_mapping[1][0]: m1, name_mapping[1][1]: m2} + + self.transform_func = named_transform + + def named_inverse_transform(x): + m1 = x[name_mapping[1][0]] + m2 = x[name_mapping[1][1]] + Mc, q = m1_m2_to_Mc_q(m1, m2) + return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + + self.inverse_transform_func = named_inverse_transform + + +class ChirpMassMassRatioToChirpMassSymmetricMassRatioTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to chirp mass and symmetric mass ratio + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + Mc = x[name_mapping[0][0]] + q = x[name_mapping[0][1]] + eta = Mc_q_to_eta(Mc, q) + return {name_mapping[1][0]: Mc, name_mapping[1][1]: eta} + + self.transform_func = named_transform + + def named_inverse_transform(x): + Mc = x[name_mapping[1][0]] + eta = x[name_mapping[1][1]] + q = eta_to_q(Mc, eta) + return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + + self.inverse_transform_func = named_inverse_transform + + # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From ff35a82f0c4db96f64507ef5bcf4b06a58738000 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 16:24:48 -0400 Subject: [PATCH 092/172] Add Bound transforming transform --- src/jimgw/transforms.py | 112 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index fec18a79..57e787a5 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -4,7 +4,8 @@ import jax import jax.numpy as jnp from chex import assert_rank -from jaxtyping import Float, Array +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped class Transform(ABC): @@ -258,6 +259,115 @@ def __init__( } +@jaxtyped(typechecker=typechecker) +class BoundToBound(BijectiveTransform): + + """ + Bound to bound transformation + """ + + original_lower_bound: Float[Array, " n_dim"] + original_upper_bound: Float[Array, " n_dim"] + target_lower_bound: Float[Array, " n_dim"] + target_upper_bound: Float[Array, " n_dim"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float[Array, " n_dim"], + original_upper_bound: Float[Array, " n_dim"], + target_lower_bound: Float[Array, " n_dim"], + target_upper_bound: Float[Array, " n_dim"], + ): + super().__init__(name_mapping) + 
self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + self.target_lower_bound = target_lower_bound + self.target_upper_bound = target_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: (x[name_mapping[0][i]] - self.original_lower_bound[i]) + * (self.target_upper_bound[i] - self.target_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) + + self.target_lower_bound[i] + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: (x[name_mapping[1][i]] - self.target_lower_bound[i]) + * (self.original_upper_bound[i] - self.original_lower_bound[i]) + / (self.target_upper_bound[i] - self.target_lower_bound[i]) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + +class BoundToUnbound(BijectiveTransform): + """ + Bound to unbound transformation + """ + + original_lower_bound: Float + original_upper_bound: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, + original_upper_bound: Float, + ): + + def logit(x): + return jnp.log(x / (1 - x)) + + super().__init__(name_mapping) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: logit( + (x[name_mapping[0][i]] - self.original_lower_bound) + / (self.original_upper_bound - self.original_lower_bound) + ) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: ( + self.original_upper_bound - self.original_lower_bound + ) + / ( + 1 + + jnp.exp(-x[name_mapping[1][i]]) + ) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + +class SingleSidedUnboundTransform(BijectiveTransform): + """ + Unbound upper limit transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the 
input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.log(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + + # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 47af9cfb2f70ac18e3547cd2f1be537f2e83aa57 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Wed, 31 Jul 2024 23:08:44 -0400 Subject: [PATCH 093/172] no nans now, but the code can be consolidate --- src/jimgw/jim.py | 13 +++++++++++-- src/jimgw/transforms.py | 5 +++-- test/integration/test_GW150914.py | 17 +++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 08878293..39867e9f 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -22,6 +22,7 @@ class Jim(object): # Name of parameters to sample from sample_transforms: list[BijectiveTransform] likelihood_transforms: list[NtoMTransform] + parameter_names: list[str] sampler: Sampler def __init__( @@ -37,9 +38,14 @@ def __init__( self.sample_transforms = sample_transforms self.likelihood_transforms = likelihood_transforms + self.parameter_names = prior.parameter_names if len(sample_transforms) == 0: print("No sample transforms provided. Using prior parameters as sampling parameters") + else: + print("Using sample transforms") + for transform in sample_transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) if len(likelihood_transforms) == 0: print("No likelihood transforms provided. 
Using prior parameters as likelihood parameters") @@ -91,12 +97,15 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): prior = self.prior.log_prob(named_params) + transform_jacobian for transform in self.likelihood_transforms: named_params = transform.forward(named_params) - return self.likelihood.evaluate(named_params, data) + prior + named_params = jax.tree.map(lambda x:x[0], named_params) # This [0] should be consolidate + return self.likelihood.evaluate(named_params, data) + prior[0] # This prior [0] should be consolidate def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T + for transform in self.sample_transforms: + initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) + initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[0] # This [0] should be consolidate self.Sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 57e787a5..b832ccdc 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -300,6 +300,7 @@ def __init__( for i in range(len(name_mapping[1])) } +@jaxtyped(typechecker=typechecker) class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -319,8 +320,8 @@ def logit(x): return jnp.log(x / (1 - x)) super().__init__(name_mapping) - self.original_lower_bound = original_lower_bound - self.original_upper_bound = original_upper_bound + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.original_upper_bound = jnp.atleast_1d(original_upper_bound) self.transform_func = lambda x: { name_mapping[1][i]: logit( diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index deb3fb98..d82e8d7a 100644 --- a/test/integration/test_GW150914.py +++ 
b/test/integration/test_GW150914.py @@ -8,6 +8,7 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -64,6 +65,21 @@ dec_prior, ] ) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, original_upper_bound=jnp.pi/2), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi) +] + likelihood = TransientLikelihoodFD( [H1, L1], waveform=RippleIMRPhenomD(), @@ -88,6 +104,7 @@ jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, n_loop_training=n_loop_training, n_loop_production=1, n_local_steps=5, From 
b5e06a6972264597a03eb7f1afd2ab42d1e1eabd Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 10:19:40 -0400 Subject: [PATCH 094/172] Solve conflict --- src/jimgw/transforms.py | 107 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 2e80e3c5..b405f066 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -3,7 +3,8 @@ import jax import jax.numpy as jnp -from jaxtyping import Float +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped from jimgw.single_event.utils import Mc_q_to_m1_m2, m1_m2_to_Mc_q, Mc_q_to_eta, eta_to_q @@ -259,6 +260,110 @@ def __init__( } +@jaxtyped(typechecker=typechecker) +class BoundToBound(BijectiveTransform): + """ + Bound to bound transformation + """ + + original_lower_bound: Float[Array, " n_dim"] + original_upper_bound: Float[Array, " n_dim"] + target_lower_bound: Float[Array, " n_dim"] + target_upper_bound: Float[Array, " n_dim"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float[Array, " n_dim"], + original_upper_bound: Float[Array, " n_dim"], + target_lower_bound: Float[Array, " n_dim"], + target_upper_bound: Float[Array, " n_dim"], + ): + super().__init__(name_mapping) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + self.target_lower_bound = target_lower_bound + self.target_upper_bound = target_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: (x[name_mapping[0][i]] - self.original_lower_bound[i]) + * (self.target_upper_bound[i] - self.target_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) + + self.target_lower_bound[i] + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: (x[name_mapping[1][i]] - self.target_lower_bound[i]) + * 
(self.original_upper_bound[i] - self.original_lower_bound[i]) + / (self.target_upper_bound[i] - self.target_lower_bound[i]) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + + +class BoundToUnbound(BijectiveTransform): + """ + Bound to unbound transformation + """ + + original_lower_bound: Float + original_upper_bound: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, + original_upper_bound: Float, + ): + + def logit(x): + return jnp.log(x / (1 - x)) + + super().__init__(name_mapping) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: logit( + (x[name_mapping[0][i]] - self.original_lower_bound) + / (self.original_upper_bound - self.original_lower_bound) + ) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound) + / (1 + jnp.exp(-x[name_mapping[1][i]])) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + + +class SingleSidedUnboundTransform(BijectiveTransform): + """ + Unbound upper limit transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.log(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): """ Transform chirp mass and mass ratio to component masses From 255fad34c3a05c1031dd3cda58e67e1e4f45dac5 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:29:56 -0400 Subject: [PATCH 095/172] Fixed powerLaw --- src/jimgw/jim.py | 11 +++++--- src/jimgw/transforms.py | 57 ++++++++++++++++++++++++----------------- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 08878293..3186c2c4 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -7,7 +7,7 @@ from jaxtyping import Array, Float, PRNGKeyArray from jimgw.base import LikelihoodBase -from jimgw.prior import Prior, trace_prior_parent +from jimgw.prior import Prior from jimgw.transforms import BijectiveTransform, NtoMTransform @@ -39,11 +39,14 @@ def __init__( self.likelihood_transforms = likelihood_transforms if len(sample_transforms) == 0: - print("No sample transforms provided. Using prior parameters as sampling parameters") + print( + "No sample transforms provided. Using prior parameters as sampling parameters" + ) if len(likelihood_transforms) == 0: - print("No likelihood transforms provided. Using prior parameters as likelihood parameters") - + print( + "No likelihood transforms provided. 
Using prior parameters as likelihood parameters" + ) seed = kwargs.get("seed", 0) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 88af8cb4..fa2ead90 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,9 +1,8 @@ -from abc import ABC, abstractmethod +from abc import ABC from typing import Callable import jax import jax.numpy as jnp -from chex import assert_rank from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped @@ -261,7 +260,6 @@ def __init__( @jaxtyped(typechecker=typechecker) class BoundToBound(BijectiveTransform): - """ Bound to bound transformation """ @@ -300,6 +298,7 @@ def __init__( for i in range(len(name_mapping[1])) } + class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -314,7 +313,7 @@ def __init__( original_lower_bound: Float, original_upper_bound: Float, ): - + def logit(x): return jnp.log(x / (1 - x)) @@ -330,17 +329,13 @@ def logit(x): for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: ( - self.original_upper_bound - self.original_lower_bound - ) - / ( - 1 - + jnp.exp(-x[name_mapping[1][i]]) - ) + name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound) + / (1 + jnp.exp(-x[name_mapping[1][i]])) + self.original_lower_bound[i] for i in range(len(name_mapping[1])) } + class SingleSidedUnboundTransform(BijectiveTransform): """ Unbound upper limit transformation @@ -367,7 +362,6 @@ def __init__( } - class PowerLawTransform(BijectiveTransform): """ PowerLaw transformation @@ -392,14 +386,22 @@ def __init__( self.xmin = xmin self.xmax = xmax self.alpha = alpha - self.transform_func = lambda x: [( - self.xmin ** (1.0 + self.alpha) - + x[0] * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha))] - self.inverse_transform_func = lambda x: [( - (x[0] ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - / (self.xmax ** (1.0 + 
self.alpha) - self.xmin ** (1.0 + self.alpha)) - )] + self.transform_func = lambda x: { + name_mapping[1][i]: ( + self.xmin ** (1.0 + self.alpha) + + x[0] + * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + ** (1.0 / (1.0 + self.alpha)) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: ( + (x[0] ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + for i in range(len(name_mapping[1])) + } class ParetoTransform(BijectiveTransform): @@ -420,7 +422,14 @@ def __init__( super().__init__(name_mapping) self.xmin = xmin self.xmax = xmax - self.transform_func = lambda x: [self.xmin * jnp.exp( - x[0] * jnp.log(self.xmax / self.xmin) - )] - self.inverse_transform_func = lambda x: [(jnp.log(x[0] / self.xmin) / jnp.log(self.xmax / self.xmin))] + self.transform_func = lambda x: { + name_mapping[1][i]: self.xmin + * jnp.exp(x[0] * jnp.log(self.xmax / self.xmin)) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: ( + jnp.log(x[0] / self.xmin) / jnp.log(self.xmax / self.xmin) + ) + for i in range(len(name_mapping[1])) + } From fcdc120014c2bc07b736a3b48b10dfcad33fb498 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:36:31 -0400 Subject: [PATCH 096/172] Fixed powerLaw --- src/jimgw/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index fa2ead90..9c7035d4 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -389,7 +389,7 @@ def __init__( self.transform_func = lambda x: { name_mapping[1][i]: ( self.xmin ** (1.0 + self.alpha) - + x[0] + + x[name_mapping[0][i]] * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) ) ** (1.0 / (1.0 + self.alpha)) @@ -397,7 +397,7 @@ def __init__( } 
self.inverse_transform_func = lambda x: { name_mapping[0][i]: ( - (x[0] ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + (x[name_mapping[1][i]] ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) ) for i in range(len(name_mapping[1])) @@ -424,12 +424,12 @@ def __init__( self.xmax = xmax self.transform_func = lambda x: { name_mapping[1][i]: self.xmin - * jnp.exp(x[0] * jnp.log(self.xmax / self.xmin)) + * jnp.exp(x[name_mapping[0][i]] * jnp.log(self.xmax / self.xmin)) for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { name_mapping[0][i]: ( - jnp.log(x[0] / self.xmin) / jnp.log(self.xmax / self.xmin) + jnp.log(x[name_mapping[1][i]] / self.xmin) / jnp.log(self.xmax / self.xmin) ) for i in range(len(name_mapping[1])) } From 06eb3ad5b430b168e0647cc8744267bf09575632 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:39:09 -0400 Subject: [PATCH 097/172] Reformatted --- src/jimgw/transforms.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 9c7035d4..cfd03b3f 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -397,7 +397,10 @@ def __init__( } self.inverse_transform_func = lambda x: { name_mapping[0][i]: ( - (x[name_mapping[1][i]] ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ( + x[name_mapping[1][i]] ** (1.0 + self.alpha) + - self.xmin ** (1.0 + self.alpha) + ) / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) ) for i in range(len(name_mapping[1])) @@ -429,7 +432,8 @@ def __init__( } self.inverse_transform_func = lambda x: { name_mapping[0][i]: ( - jnp.log(x[name_mapping[1][i]] / self.xmin) / jnp.log(self.xmax / self.xmin) + jnp.log(x[name_mapping[1][i]] / self.xmin) + / jnp.log(self.xmax / self.xmin) ) for i in range(len(name_mapping[1])) } From 
8e5b326c295f94f2390ab40427dd63455f2839b0 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 11:06:30 -0400 Subject: [PATCH 098/172] Add sky position transform --- src/jimgw/single_event/utils.py | 143 +++++++++++++++++++++++++------- src/jimgw/transforms.py | 47 ++++++++++- 2 files changed, 155 insertions(+), 35 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index aaf4b947..2373c5be 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -203,32 +203,6 @@ def eta_to_q(eta: Float) -> Float: return temp - (temp**2 - 1) ** 0.5 -def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: - """ - Transforming the right ascension ra and declination dec to the polar angle - theta and azimuthal angle phi. - - Parameters - ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. - - Returns - ------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. 
- """ - phi = ra - gmst - theta = jnp.pi / 2 - dec - return theta, phi - - def euler_rotation(delta_x: Float[Array, " 3"]): """ Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x @@ -239,11 +213,10 @@ def euler_rotation(delta_x: Float[Array, " 3"]): Copied and modified from bilby-cython/geometry.pyx """ - norm = jnp.power( - delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5 - ) + norm = jnp.linalg.vector_norm(delta_x) + cos_beta = delta_x[2] / norm - sin_beta = jnp.power(1 - cos_beta**2, 0.5) + sin_beta = jnp.sqrt(1 - cos_beta**2) alpha = jnp.atan2(-delta_x[1] * cos_beta, delta_x[0]) gamma = jnp.atan2(delta_x[1], delta_x[0]) @@ -318,7 +291,7 @@ def zenith_azimuth_to_theta_phi( + rotation[0][2] * cos_zenith, ) + 2 * jnp.pi, - (2 * jnp.pi), + 2 * jnp.pi, ) return theta, phi @@ -345,6 +318,7 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F """ ra = phi + gmst dec = jnp.pi / 2 - theta + ra = ra % (2 * jnp.pi) return ra, dec @@ -376,10 +350,115 @@ def zenith_azimuth_to_ra_dec( """ theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) - ra = ra % (2 * jnp.pi) return ra, dec +def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the right ascension ra and declination dec to the polar angle + theta and azimuthal angle phi. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ + phi = ra - gmst + theta = jnp.pi / 2 - dec + phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) + return theta, phi + + +def theta_phi_to_zenith_azimuth( + theta: Float, phi: Float, delta_x: Float[Array, " 3"] +) -> tuple[Float, Float]: + """ + Transforming the polar angle and azimuthal angle to the zenith angle and azimuthal angle. 
+ + Parameters + ---------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + delta_x : Float + The vector pointing from the first detector to the second detector. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + """ + sin_theta = jnp.sin(theta) + cos_theta = jnp.cos(theta) + sin_phi = jnp.sin(phi) + cos_phi = jnp.cos(phi) + + rotation = euler_rotation(delta_x) + rotation = jnp.linalg.inv(rotation) + + zenith = jnp.acos( + rotation[2][0] * sin_theta * cos_phi + + rotation[2][1] * sin_theta * sin_phi + + rotation[2][2] * cos_theta + ) + azimuth = jnp.fmod( + jnp.atan2( + rotation[1][0] * sin_theta * cos_phi + + rotation[1][1] * sin_theta * sin_phi + + rotation[1][2] * cos_theta, + rotation[0][0] * sin_theta * cos_phi + + rotation[0][1] * sin_theta * sin_phi + + rotation[0][2] * cos_theta, + ) + + 2 * jnp.pi, + 2 * jnp.pi, + ) + return zenith, azimuth + + +def ra_dec_to_zenith_azimuth( + ra: Float, dec: Float, gmst: Float, delta_x: Float[Array, " 3"] +) -> tuple[Float, Float]: + """ + Transforming the right ascension and declination to the zenith angle and azimuthal angle. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + delta_x : Float + The vector pointing from the first detector to the second detector. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. 
+ """ + theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) + zenith, azimuth = theta_phi_to_zenith_azimuth(theta, phi, delta_x) + return zenith, azimuth + + def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]: """ A numerically stable method to evaluate log of diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 5d0d1643..d644c8e1 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -6,7 +6,14 @@ from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped -from jimgw.single_event.utils import Mc_q_to_m1_m2, m1_m2_to_Mc_q, Mc_q_to_eta, eta_to_q +from jimgw.single_event.utils import ( + Mc_q_to_m1_m2, + m1_m2_to_Mc_q, + Mc_q_to_eta, + eta_to_q, + ra_dec_to_zenith_azimuth, + zenith_azimuth_to_ra_dec, +) class Transform(ABC): @@ -300,7 +307,7 @@ def __init__( for i in range(len(name_mapping[1])) } - + class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -315,7 +322,7 @@ def __init__( original_lower_bound: Float, original_upper_bound: Float, ): - + def logit(x): return jnp.log(x / (1 - x)) @@ -432,6 +439,40 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): + """ + Transform sky frame to detector frame sky position + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + ra = x[name_mapping[0][0]] + dec = x[name_mapping[0][1]] + zenith, azimuth = ra_dec_to_zenith_azimuth(ra, dec) + return {name_mapping[1][0]: zenith, name_mapping[1][1]: azimuth} + + self.transform_func = named_transform + + def named_inverse_transform(x): + zenith = x[name_mapping[1][0]] + azimuth = x[name_mapping[1][1]] + ra, dec = zenith_azimuth_to_ra_dec(zenith, azimuth) + return {name_mapping[0][0]: ra, name_mapping[0][1]: dec} + + self.inverse_transform_func = named_inverse_transform + + # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 10d51b2b05723b4e0eb3e9b0e3bf2342b841b535 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 11:21:06 -0400 Subject: [PATCH 099/172] Modify sky position transform --- src/jimgw/single_event/utils.py | 64 ++++----------------------------- src/jimgw/transforms.py | 19 ++++++++-- 2 files changed, 23 insertions(+), 60 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 2373c5be..705857fe 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -245,8 +245,8 @@ def euler_rotation(delta_x: Float[Array, " 3"]): return rotation -def zenith_azimuth_to_theta_phi( - zenith: Float, azimuth: Float, delta_x: Float[Array, " 3"] +def angle_rotation( + zenith: Float, azimuth: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame. 
@@ -274,8 +274,6 @@ def zenith_azimuth_to_theta_phi( sin_zenith = jnp.sin(zenith) cos_zenith = jnp.cos(zenith) - rotation = euler_rotation(delta_x) - theta = jnp.acos( rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth @@ -323,7 +321,7 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F def zenith_azimuth_to_ra_dec( - zenith: Float, azimuth: Float, gmst: Float, delta_x: Float[Array, " 3"] + zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. @@ -348,7 +346,7 @@ def zenith_azimuth_to_ra_dec( dec : Float Declination. """ - theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + theta, phi = angle_rotation(zenith, azimuth, rotation) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) return ra, dec @@ -380,58 +378,8 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa return theta, phi -def theta_phi_to_zenith_azimuth( - theta: Float, phi: Float, delta_x: Float[Array, " 3"] -) -> tuple[Float, Float]: - """ - Transforming the polar angle and azimuthal angle to the zenith angle and azimuthal angle. - - Parameters - ---------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. - delta_x : Float - The vector pointing from the first detector to the second detector. - - Returns - ------- - zenith : Float - Zenith angle. - azimuth : Float - Azimuthal angle. 
- """ - sin_theta = jnp.sin(theta) - cos_theta = jnp.cos(theta) - sin_phi = jnp.sin(phi) - cos_phi = jnp.cos(phi) - - rotation = euler_rotation(delta_x) - rotation = jnp.linalg.inv(rotation) - - zenith = jnp.acos( - rotation[2][0] * sin_theta * cos_phi - + rotation[2][1] * sin_theta * sin_phi - + rotation[2][2] * cos_theta - ) - azimuth = jnp.fmod( - jnp.atan2( - rotation[1][0] * sin_theta * cos_phi - + rotation[1][1] * sin_theta * sin_phi - + rotation[1][2] * cos_theta, - rotation[0][0] * sin_theta * cos_phi - + rotation[0][1] * sin_theta * sin_phi - + rotation[0][2] * cos_theta, - ) - + 2 * jnp.pi, - 2 * jnp.pi, - ) - return zenith, azimuth - - def ra_dec_to_zenith_azimuth( - ra: Float, dec: Float, gmst: Float, delta_x: Float[Array, " 3"] + ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the right ascension and declination to the zenith angle and azimuthal angle. @@ -455,7 +403,7 @@ def ra_dec_to_zenith_azimuth( Azimuthal angle. 
""" theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) - zenith, azimuth = theta_phi_to_zenith_azimuth(theta, phi, delta_x) + zenith, azimuth = angle_rotation(theta, phi, rotation) return zenith, azimuth diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d644c8e1..5a889569 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -13,6 +13,7 @@ eta_to_q, ra_dec_to_zenith_azimuth, zenith_azimuth_to_ra_dec, + euler_rotation, ) @@ -450,16 +451,28 @@ class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): """ + gmst: Float + rotation: Float[Array, " 3 3"] + rotation_inv: Float[Array, " 3 3"] + def __init__( self, name_mapping: tuple[list[str], list[str]], + gmst: Float, + delta_x: Float, ): super().__init__(name_mapping) + self.gmst = gmst + self.rotation = euler_rotation(delta_x) + self.rotation_inv = jnp.linalg.inv(self.rotation) + def named_transform(x): ra = x[name_mapping[0][0]] dec = x[name_mapping[0][1]] - zenith, azimuth = ra_dec_to_zenith_azimuth(ra, dec) + zenith, azimuth = ra_dec_to_zenith_azimuth( + ra, dec, self.gmst, self.rotation + ) return {name_mapping[1][0]: zenith, name_mapping[1][1]: azimuth} self.transform_func = named_transform @@ -467,7 +480,9 @@ def named_transform(x): def named_inverse_transform(x): zenith = x[name_mapping[1][0]] azimuth = x[name_mapping[1][1]] - ra, dec = zenith_azimuth_to_ra_dec(zenith, azimuth) + ra, dec = zenith_azimuth_to_ra_dec( + zenith, azimuth, self.gmst, self.rotation_inv + ) return {name_mapping[0][0]: ra, name_mapping[0][1]: dec} self.inverse_transform_func = named_inverse_transform From 8368d00bd9a5d3a41e670036bdcfcbeec07dfef3 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 11:55:30 -0400 Subject: [PATCH 100/172] Change util func name --- src/jimgw/single_event/utils.py | 2 +- src/jimgw/transforms.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 705857fe..8721c176 
100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -163,7 +163,7 @@ def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: return M, eta -def Mc_q_to_eta(M_c: Float, q: Float) -> Float: +def q_to_eta(q: Float) -> Float: """ Transforming the chirp mass M_c and mass ratio q to the symmetric mass ratio eta. diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 5a889569..165432eb 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -9,7 +9,7 @@ from jimgw.single_event.utils import ( Mc_q_to_m1_m2, m1_m2_to_Mc_q, - Mc_q_to_eta, + q_to_eta, eta_to_q, ra_dec_to_zenith_azimuth, zenith_azimuth_to_ra_dec, @@ -426,7 +426,7 @@ def __init__( def named_transform(x): Mc = x[name_mapping[0][0]] q = x[name_mapping[0][1]] - eta = Mc_q_to_eta(Mc, q) + eta = q_to_eta(q) return {name_mapping[1][0]: Mc, name_mapping[1][1]: eta} self.transform_func = named_transform From 1cb0d11c2f58c8413cc34f823ca7943c0f820f59 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 11:59:38 -0400 Subject: [PATCH 101/172] Revert "Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-class' into transform" This reverts commit b4f60527bf7ec4e002698e32458b7da659a88da5, reversing changes made to 8368d00bd9a5d3a41e670036bdcfcbeec07dfef3. 
--- src/jimgw/jim.py | 24 ++--- src/jimgw/transforms.py | 147 +++++++++++++++++++++++++++--- test/integration/test_GW150914.py | 17 ---- 3 files changed, 142 insertions(+), 46 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 39867e9f..3186c2c4 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -7,7 +7,7 @@ from jaxtyping import Array, Float, PRNGKeyArray from jimgw.base import LikelihoodBase -from jimgw.prior import Prior, trace_prior_parent +from jimgw.prior import Prior from jimgw.transforms import BijectiveTransform, NtoMTransform @@ -22,7 +22,6 @@ class Jim(object): # Name of parameters to sample from sample_transforms: list[BijectiveTransform] likelihood_transforms: list[NtoMTransform] - parameter_names: list[str] sampler: Sampler def __init__( @@ -38,18 +37,16 @@ def __init__( self.sample_transforms = sample_transforms self.likelihood_transforms = likelihood_transforms - self.parameter_names = prior.parameter_names if len(sample_transforms) == 0: - print("No sample transforms provided. Using prior parameters as sampling parameters") - else: - print("Using sample transforms") - for transform in sample_transforms: - self.parameter_names = transform.propagate_name(self.parameter_names) + print( + "No sample transforms provided. Using prior parameters as sampling parameters" + ) if len(likelihood_transforms) == 0: - print("No likelihood transforms provided. Using prior parameters as likelihood parameters") - + print( + "No likelihood transforms provided. 
Using prior parameters as likelihood parameters" + ) seed = kwargs.get("seed", 0) @@ -97,15 +94,12 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): prior = self.prior.log_prob(named_params) + transform_jacobian for transform in self.likelihood_transforms: named_params = transform.forward(named_params) - named_params = jax.tree.map(lambda x:x[0], named_params) # This [0] should be consolidate - return self.likelihood.evaluate(named_params, data) + prior[0] # This prior [0] should be consolidate + return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) - for transform in self.sample_transforms: - initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[0] # This [0] should be consolidate + initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T self.Sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index b832ccdc..165432eb 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,12 +1,21 @@ -from abc import ABC, abstractmethod +from abc import ABC from typing import Callable import jax import jax.numpy as jnp -from chex import assert_rank from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped +from jimgw.single_event.utils import ( + Mc_q_to_m1_m2, + m1_m2_to_Mc_q, + q_to_eta, + eta_to_q, + ra_dec_to_zenith_azimuth, + zenith_azimuth_to_ra_dec, + euler_rotation, +) + class Transform(ABC): """ @@ -261,7 +270,6 @@ def __init__( @jaxtyped(typechecker=typechecker) class BoundToBound(BijectiveTransform): - """ Bound to bound transformation """ @@ -300,7 +308,7 @@ def __init__( for i in range(len(name_mapping[1])) } 
-@jaxtyped(typechecker=typechecker) + class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -315,13 +323,13 @@ def __init__( original_lower_bound: Float, original_upper_bound: Float, ): - + def logit(x): return jnp.log(x / (1 - x)) super().__init__(name_mapping) - self.original_lower_bound = jnp.atleast_1d(original_lower_bound) - self.original_upper_bound = jnp.atleast_1d(original_upper_bound) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound self.transform_func = lambda x: { name_mapping[1][i]: logit( @@ -331,17 +339,13 @@ def logit(x): for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: ( - self.original_upper_bound - self.original_lower_bound - ) - / ( - 1 - + jnp.exp(-x[name_mapping[1][i]]) - ) + name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound) + / (1 + jnp.exp(-x[name_mapping[1][i]])) + self.original_lower_bound[i] for i in range(len(name_mapping[1])) } + class SingleSidedUnboundTransform(BijectiveTransform): """ Unbound upper limit transformation @@ -368,6 +372,121 @@ def __init__( } +class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to component masses + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + Mc = x[name_mapping[0][0]] + q = x[name_mapping[0][1]] + m1, m2 = Mc_q_to_m1_m2(Mc, q) + return {name_mapping[1][0]: m1, name_mapping[1][1]: m2} + + self.transform_func = named_transform + + def named_inverse_transform(x): + m1 = x[name_mapping[1][0]] + m2 = x[name_mapping[1][1]] + Mc, q = m1_m2_to_Mc_q(m1, m2) + return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + + self.inverse_transform_func = named_inverse_transform + + +class ChirpMassMassRatioToChirpMassSymmetricMassRatioTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to chirp mass and symmetric mass ratio + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + Mc = x[name_mapping[0][0]] + q = x[name_mapping[0][1]] + eta = q_to_eta(q) + return {name_mapping[1][0]: Mc, name_mapping[1][1]: eta} + + self.transform_func = named_transform + + def named_inverse_transform(x): + Mc = x[name_mapping[1][0]] + eta = x[name_mapping[1][1]] + q = eta_to_q(Mc, eta) + return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + + self.inverse_transform_func = named_inverse_transform + + +class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): + """ + Transform sky frame to detector frame sky position + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + gmst: Float + rotation: Float[Array, " 3 3"] + rotation_inv: Float[Array, " 3 3"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gmst: Float, + delta_x: Float, + ): + super().__init__(name_mapping) + + self.gmst = gmst + self.rotation = euler_rotation(delta_x) + self.rotation_inv = jnp.linalg.inv(self.rotation) + + def named_transform(x): + ra = x[name_mapping[0][0]] + dec = x[name_mapping[0][1]] + zenith, azimuth = ra_dec_to_zenith_azimuth( + ra, dec, self.gmst, self.rotation + ) + return {name_mapping[1][0]: zenith, name_mapping[1][1]: azimuth} + + self.transform_func = named_transform + + def named_inverse_transform(x): + zenith = x[name_mapping[1][0]] + azimuth = x[name_mapping[1][1]] + ra, dec = zenith_azimuth_to_ra_dec( + zenith, azimuth, self.gmst, self.rotation_inv + ) + return {name_mapping[0][0]: ra, name_mapping[0][1]: dec} + + self.inverse_transform_func = named_inverse_transform + # class PowerLawTransform(UnivariateTransform): # """ diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index d82e8d7a..deb3fb98 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,7 +8,6 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD -from jimgw.transforms import BoundToUnbound from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -65,21 +64,6 @@ dec_prior, ] ) - -sample_transforms = [ - BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25), - BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] 
, original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), - BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), - BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, original_upper_bound=jnp.pi/2), - BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi) -] - likelihood = TransientLikelihoodFD( [H1, L1], waveform=RippleIMRPhenomD(), @@ -104,7 +88,6 @@ jim = Jim( likelihood, prior, - sample_transforms=sample_transforms, n_loop_training=n_loop_training, n_loop_production=1, n_local_steps=5, From 78e93c502ed24b77f74f2cbb8b662ca54f3c4d56 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 12:03:00 -0400 Subject: [PATCH 102/172] Merge --- src/jimgw/jim.py | 19 +++++++++++++++++-- src/jimgw/transforms.py | 5 +++-- test/integration/test_GW150914.py | 17 +++++++++++++++++ 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 3186c2c4..2a9cb952 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -22,6 +22,7 @@ class Jim(object): # Name of parameters to sample from sample_transforms: list[BijectiveTransform] likelihood_transforms: list[NtoMTransform] + parameter_names: list[str] sampler: Sampler def __init__( @@ -37,11 +38,16 @@ def __init__( self.sample_transforms = sample_transforms self.likelihood_transforms = likelihood_transforms + self.parameter_names = 
prior.parameter_names if len(sample_transforms) == 0: print( "No sample transforms provided. Using prior parameters as sampling parameters" ) + else: + print("Using sample transforms") + for transform in sample_transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) if len(likelihood_transforms) == 0: print( @@ -94,12 +100,21 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): prior = self.prior.log_prob(named_params) + transform_jacobian for transform in self.likelihood_transforms: named_params = transform.forward(named_params) - return self.likelihood.evaluate(named_params, data) + prior + named_params = jax.tree.map( + lambda x: x[0], named_params + ) # This [0] should be consolidate + return ( + self.likelihood.evaluate(named_params, data) + prior[0] + ) # This prior [0] should be consolidate def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T + for transform in self.sample_transforms: + initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) + initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[ + 0 + ] # This [0] should be consolidate self.Sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 165432eb..10c42c33 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -309,6 +309,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -328,8 +329,8 @@ def logit(x): return jnp.log(x / (1 - x)) super().__init__(name_mapping) - self.original_lower_bound = original_lower_bound - self.original_upper_bound = original_upper_bound + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + 
self.original_upper_bound = jnp.atleast_1d(original_upper_bound) self.transform_func = lambda x: { name_mapping[1][i]: logit( diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index deb3fb98..d82e8d7a 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,6 +8,7 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -64,6 +65,21 @@ dec_prior, ] ) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, original_upper_bound=jnp.pi/2), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], 
["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi) +] + likelihood = TransientLikelihoodFD( [H1, L1], waveform=RippleIMRPhenomD(), @@ -88,6 +104,7 @@ jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, n_loop_training=n_loop_training, n_loop_production=1, n_local_steps=5, From 080bd8bb36b97f4a87ec2cbd7d84900e42d9b9e8 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 12:51:35 -0400 Subject: [PATCH 103/172] Modify integration test --- src/jimgw/transforms.py | 21 ++++----------------- test/integration/test_GW150914.py | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 10c42c33..3fc3cd1a 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -407,9 +407,9 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform -class ChirpMassMassRatioToChirpMassSymmetricMassRatioTransform(BijectiveTransform): +class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): """ - Transform chirp mass and mass ratio to chirp mass and symmetric mass ratio + Transform mass ratio to symmetric mass ratio Parameters ---------- @@ -424,21 +424,8 @@ def __init__( ): super().__init__(name_mapping) - def named_transform(x): - Mc = x[name_mapping[0][0]] - q = x[name_mapping[0][1]] - eta = q_to_eta(q) - return {name_mapping[1][0]: Mc, name_mapping[1][1]: eta} - - self.transform_func = named_transform - - def named_inverse_transform(x): - Mc = x[name_mapping[1][0]] - eta = x[name_mapping[1][1]] - q = eta_to_q(Mc, eta) - return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} - - self.inverse_transform_func = named_inverse_transform + self.transform_func = lambda x: {name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]])} + self.inverse_transform_func = lambda x: {name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]])} class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): diff --git 
a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index d82e8d7a..4cee35e9 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,7 +8,7 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD -from jimgw.transforms import BoundToUnbound +from jimgw.transforms import BoundToUnbound, MassRatioToSymmetricMassRatioTransform from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -34,10 +34,10 @@ L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -eta_prior = UniformPrior( +q_prior = UniformPrior( 0.125, - 0.25, - parameter_names=["eta"], # Need name transformation in likelihood to work + 1, + parameter_names=["q"], ) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) @@ -45,15 +45,15 @@ dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -iota_prior = CosinePrior(parameter_names=["iota"]) +iota_prior = SinePrior(parameter_names=["iota"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) -dec_prior = SinePrior(parameter_names=["dec"]) +dec_prior = CosinePrior(parameter_names=["dec"]) prior = CombinePrior( [ Mc_prior, - eta_prior, + q_prior, s1z_prior, s2z_prior, dL_prior, @@ -68,7 +68,7 @@ sample_transforms = [ BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25), + BoundToUnbound(name_mapping 
= [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), @@ -80,6 +80,10 @@ BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi) ] +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + likelihood = TransientLikelihoodFD( [H1, L1], waveform=RippleIMRPhenomD(), @@ -105,6 +109,7 @@ likelihood, prior, sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=1, n_local_steps=5, From 4058d327418b3dbd9b742cb7af212a76eac8162f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 12:56:23 -0400 Subject: [PATCH 104/172] Reformat --- src/jimgw/transforms.py | 8 ++++++-- test/integration/test_GW150914.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 3fc3cd1a..0326b28c 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -424,8 +424,12 @@ def __init__( ): super().__init__(name_mapping) - self.transform_func = lambda x: {name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]])} - self.inverse_transform_func = lambda x: {name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]])} + self.transform_func = lambda x: { + name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]]) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]]) + } class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 
4cee35e9..753eb6b1 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -36,7 +36,7 @@ Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) q_prior = UniformPrior( 0.125, - 1, + 1., parameter_names=["q"], ) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) @@ -68,7 +68,7 @@ sample_transforms = [ BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), From cdf771d1f3225045607601353a3f9e7ef3cfac54 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 13:03:35 -0400 Subject: [PATCH 105/172] Add typecheck --- src/jimgw/transforms.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 0326b28c..1bb17c18 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -174,6 +174,7 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: return y_copy +@jaxtyped(typechecker=typechecker) class ScaleTransform(BijectiveTransform): scale: Float @@ -194,6 +195,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class OffsetTransform(BijectiveTransform): offset: Float @@ -214,6 +216,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class LogitTransform(BijectiveTransform): """ Logit transform following @@ -242,6 +245,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class 
ArcSineTransform(BijectiveTransform): """ ArcSine transformation @@ -347,6 +351,7 @@ def logit(x): } +@jaxtyped(typechecker=typechecker) class SingleSidedUnboundTransform(BijectiveTransform): """ Unbound upper limit transformation @@ -373,6 +378,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): """ Transform chirp mass and mass ratio to component masses @@ -407,6 +413,7 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): """ Transform mass ratio to symmetric mass ratio @@ -432,6 +439,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): """ Transform sky frame to detector frame sky position From 02c5650a064fa749c2a63a9e2c885c230ca05523 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 13:21:35 -0400 Subject: [PATCH 106/172] minor typo --- src/jimgw/jim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 2a9cb952..19f74606 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -70,7 +70,7 @@ def __init__( self.prior.n_dim, num_layers, hidden_size, num_bins, subkey ) - self.Sampler = Sampler( + self.sampler = Sampler( self.prior.n_dim, rng_key, None, # type: ignore From ce7ac34d2e8094ab1cfaa31fafa657c1887fcf45 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 13:38:40 -0400 Subject: [PATCH 107/172] Rename sampler --- src/jimgw/jim.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 19f74606..f21aff42 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -109,13 +109,13 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if 
initial_guess.size == 0: - initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) + initial_guess_named = self.prior.sample(key, self.sampler.n_chains) for transform in self.sample_transforms: initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[ 0 ] # This [0] should be consolidate - self.Sampler.sample(initial_guess, None) # type: ignore + self.sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( self, @@ -148,8 +148,8 @@ def print_summary(self, transform: bool = True): """ - train_summary = self.Sampler.get_sampler_state(training=True) - production_summary = self.Sampler.get_sampler_state(training=False) + train_summary = self.sampler.get_sampler_state(training=True) + production_summary = self.sampler.get_sampler_state(training=False) training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T training_chain = self.prior.add_name(training_chain) @@ -215,9 +215,9 @@ def get_samples(self, training: bool = False) -> dict: """ if training: - chains = self.Sampler.get_sampler_state(training=True)["chains"] + chains = self.sampler.get_sampler_state(training=True)["chains"] else: - chains = self.Sampler.get_sampler_state(training=False)["chains"] + chains = self.sampler.get_sampler_state(training=False)["chains"] chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1))) return chains From 7d44aa4ba9763a31a8310b40b3eac29d47e47893 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 13:40:53 -0400 Subject: [PATCH 108/172] Fix test --- test/integration/test_GW150914.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 753eb6b1..6193032a 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -74,10 +74,10 @@ BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , 
original_lower_bound=0.0, original_upper_bound=2000.0), BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, original_upper_bound=jnp.pi/2), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi) + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) ] likelihood_transforms = [ From 45be9ae70c570b681831879467feb6f969d7062f Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:31:56 -0400 Subject: [PATCH 109/172] Revert "Fixed PowerLaw prior" --- src/jimgw/jim.py | 7 +- src/jimgw/prior.py | 4 +- src/jimgw/transforms.py | 142 ++++++++++++++++++---------------------- test/unit/test_prior.py | 18 +++-- 4 files changed, 78 insertions(+), 93 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index ce097d44..39867e9f 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -7,7 +7,7 @@ from jaxtyping import Array, Float, PRNGKeyArray from jimgw.base import LikelihoodBase -from jimgw.prior import Prior +from jimgw.prior import Prior, trace_prior_parent from jimgw.transforms import BijectiveTransform, NtoMTransform @@ -48,9 +48,8 @@ def __init__( self.parameter_names = transform.propagate_name(self.parameter_names) if len(likelihood_transforms) == 0: 
- print( - "No likelihood transforms provided. Using prior parameters as likelihood parameters" - ) + print("No likelihood transforms provided. Using prior parameters as likelihood parameters") + seed = kwargs.get("seed", 0) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index aff1ebf5..13ab622f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -12,8 +12,8 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - PowerLawTransform, - ParetoTransform, + # PowerLawTransform, + # ParetoTransform, ) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 75d0420a..b832ccdc 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,8 +1,9 @@ -from abc import ABC +from abc import ABC, abstractmethod from typing import Callable import jax import jax.numpy as jnp +from chex import assert_rank from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped @@ -260,6 +261,7 @@ def __init__( @jaxtyped(typechecker=typechecker) class BoundToBound(BijectiveTransform): + """ Bound to bound transformation """ @@ -298,7 +300,6 @@ def __init__( for i in range(len(name_mapping[1])) } - @jaxtyped(typechecker=typechecker) class BoundToUnbound(BijectiveTransform): """ @@ -314,7 +315,7 @@ def __init__( original_lower_bound: Float, original_upper_bound: Float, ): - + def logit(x): return jnp.log(x / (1 - x)) @@ -330,13 +331,17 @@ def logit(x): for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound) - / (1 + jnp.exp(-x[name_mapping[1][i]])) + name_mapping[0][i]: ( + self.original_upper_bound - self.original_lower_bound + ) + / ( + 1 + + jnp.exp(-x[name_mapping[1][i]]) + ) + self.original_lower_bound[i] for i in range(len(name_mapping[1])) } - class SingleSidedUnboundTransform(BijectiveTransform): """ Unbound upper limit transformation @@ -363,78 +368,55 @@ def __init__( } -class PowerLawTransform(BijectiveTransform): - """ 
- PowerLaw transformation - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ - - xmin: Float - xmax: Float - alpha: Float - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - alpha: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.alpha = alpha - self.transform_func = lambda x: { - name_mapping[1][i]: ( - self.xmin ** (1.0 + self.alpha) - + x[name_mapping[0][i]] - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) - ** (1.0 / (1.0 + self.alpha)) - for i in range(len(name_mapping[0])) - } - self.inverse_transform_func = lambda x: { - name_mapping[0][i]: ( - ( - x[name_mapping[1][i]] ** (1.0 + self.alpha) - - self.xmin ** (1.0 + self.alpha) - ) - / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) - for i in range(len(name_mapping[1])) - } - - -class ParetoTransform(BijectiveTransform): - """ - Pareto transformation: Power law when alpha = -1 - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.transform_func = lambda x: { - name_mapping[1][i]: self.xmin - * jnp.exp(x[name_mapping[0][i]] * jnp.log(self.xmax / self.xmin)) - for i in range(len(name_mapping[0])) - } - self.inverse_transform_func = lambda x: { - name_mapping[0][i]: ( - jnp.log(x[name_mapping[1][i]] / self.xmin) - / jnp.log(self.xmax / self.xmin) - ) - for i in range(len(name_mapping[1])) - } +# class PowerLawTransform(UnivariateTransform): +# """ +# PowerLaw transformation +# Parameters +# ---------- +# name_mapping : tuple[list[str], list[str]] +# The name mapping between the input and output dictionary. 
+# """ + +# xmin: Float +# xmax: Float +# alpha: Float + +# def __init__( +# self, +# name_mapping: tuple[list[str], list[str]], +# xmin: Float, +# xmax: Float, +# alpha: Float, +# ): +# super().__init__(name_mapping) +# self.xmin = xmin +# self.xmax = xmax +# self.alpha = alpha +# self.transform_func = lambda x: ( +# self.xmin ** (1.0 + self.alpha) +# + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) + + +# class ParetoTransform(UnivariateTransform): +# """ +# Pareto transformation: Power law when alpha = -1 +# Parameters +# ---------- +# name_mapping : tuple[list[str], list[str]] +# The name mapping between the input and output dictionary. +# """ + +# def __init__( +# self, +# name_mapping: tuple[list[str], list[str]], +# xmin: Float, +# xmax: Float, +# ): +# super().__init__(name_mapping) +# self.xmin = xmin +# self.xmax = xmax +# self.transform_func = lambda x: self.xmin * jnp.exp( +# x * jnp.log(self.xmax / self.xmin) +# ) diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index 1cbed508..20d71de4 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -96,16 +96,20 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = p.sample(jax.random.PRNGKey(0), 10000) - log_p = jax.vmap(p.log_prob, [0])(samples) - assert jnp.all(jnp.isfinite(log_p)) + samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x_base'] + base_log_p = jax.vmap(p.log_prob, [0])({'x_base':samples}) + assert jnp.all(jnp.isfinite(base_log_p)) # Check that the log_prob is correct in the support - samples = p.sample(jax.random.PRNGKey(0), 10000) - log_prob = jax.vmap(p.log_prob)(samples) - standard_log_prob = powerlaw_log_pdf(samples['x'], alpha, xmin, xmax) + samples = jnp.linspace(-10.0, 10.0, 1000) + transformed_samples = jax.vmap(p.transform)({'x_base': samples})['x'] + # cut off the samples that are outside the 
support + samples = samples[transformed_samples >= xmin] + transformed_samples = transformed_samples[transformed_samples >= xmin] + samples = samples[transformed_samples <= xmax] + transformed_samples = transformed_samples[transformed_samples <= xmax] # log pdf of powerlaw - assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4) + assert jnp.allclose(jax.vmap(p.log_prob)({'x_base':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) # Test Pareto Transform func(-1.0) From e9288c8116b2209cdc1de477bac5fc0742cc9482 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 14:43:14 -0400 Subject: [PATCH 110/172] Fix BoundToUnbound transform --- src/jimgw/jim.py | 11 +++-------- src/jimgw/transforms.py | 6 +++--- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index f21aff42..4ccf0451 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -100,21 +100,16 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): prior = self.prior.log_prob(named_params) + transform_jacobian for transform in self.likelihood_transforms: named_params = transform.forward(named_params) - named_params = jax.tree.map( - lambda x: x[0], named_params - ) # This [0] should be consolidate return ( - self.likelihood.evaluate(named_params, data) + prior[0] - ) # This prior [0] should be consolidate + self.likelihood.evaluate(named_params, data) + prior + ) def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess_named = self.prior.sample(key, self.sampler.n_chains) for transform in self.sample_transforms: initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[ - 0 - ] # This [0] should be consolidate + initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T self.sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( diff 
--git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 1bb17c18..c6c4332f 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -338,13 +338,13 @@ def logit(x): self.transform_func = lambda x: { name_mapping[1][i]: logit( - (x[name_mapping[0][i]] - self.original_lower_bound) - / (self.original_upper_bound - self.original_lower_bound) + (x[name_mapping[0][i]] - self.original_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) ) for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound) + name_mapping[0][i]: (self.original_upper_bound[i] - self.original_lower_bound[i]) / (1 + jnp.exp(-x[name_mapping[1][i]])) + self.original_lower_bound[i] for i in range(len(name_mapping[1])) From d5a34f25acaac917df94e8e90ad85574116f4d50 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 1 Aug 2024 15:03:55 -0400 Subject: [PATCH 111/172] Updated prior.py and transforms.py --- src/jimgw/prior.py | 4 +- src/jimgw/transforms.py | 125 ++++++++++++++++++++++++---------------- 2 files changed, 76 insertions(+), 53 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 13ab622f..aff1ebf5 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -12,8 +12,8 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - # PowerLawTransform, - # ParetoTransform, + PowerLawTransform, + ParetoTransform, ) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index b832ccdc..5a154f96 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -368,55 +368,78 @@ def __init__( } +class PowerLawTransform(BijectiveTransform): + """ + PowerLaw transformation + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ """ + + xmin: Float + xmax: Float + alpha: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + alpha: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.alpha = alpha + self.transform_func = lambda x: { + name_mapping[1][i]: ( + self.xmin ** (1.0 + self.alpha) + + x[name_mapping[0][i]] + * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + ** (1.0 / (1.0 + self.alpha)) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: ( + ( + x[name_mapping[1][i]] ** (1.0 + self.alpha) + - self.xmin ** (1.0 + self.alpha) + ) + / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + for i in range(len(name_mapping[1])) + } -# class PowerLawTransform(UnivariateTransform): -# """ -# PowerLaw transformation -# Parameters -# ---------- -# name_mapping : tuple[list[str], list[str]] -# The name mapping between the input and output dictionary. -# """ - -# xmin: Float -# xmax: Float -# alpha: Float - -# def __init__( -# self, -# name_mapping: tuple[list[str], list[str]], -# xmin: Float, -# xmax: Float, -# alpha: Float, -# ): -# super().__init__(name_mapping) -# self.xmin = xmin -# self.xmax = xmax -# self.alpha = alpha -# self.transform_func = lambda x: ( -# self.xmin ** (1.0 + self.alpha) -# + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) -# ) ** (1.0 / (1.0 + self.alpha)) - - -# class ParetoTransform(UnivariateTransform): -# """ -# Pareto transformation: Power law when alpha = -1 -# Parameters -# ---------- -# name_mapping : tuple[list[str], list[str]] -# The name mapping between the input and output dictionary. 
-# """ - -# def __init__( -# self, -# name_mapping: tuple[list[str], list[str]], -# xmin: Float, -# xmax: Float, -# ): -# super().__init__(name_mapping) -# self.xmin = xmin -# self.xmax = xmax -# self.transform_func = lambda x: self.xmin * jnp.exp( -# x * jnp.log(self.xmax / self.xmin) -# ) + +class ParetoTransform(BijectiveTransform): + """ + Pareto transformation: Power law when alpha = -1 + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.transform_func = lambda x: { + name_mapping[1][i]: self.xmin + * jnp.exp(x[name_mapping[0][i]] * jnp.log(self.xmax / self.xmin)) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: ( + jnp.log(x[name_mapping[1][i]] / self.xmin) + / jnp.log(self.xmax / self.xmin) + ) + for i in range(len(name_mapping[1])) + } From 34f3d260ac937fd4bd763b16f084f7fc99437405 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 1 Aug 2024 15:07:42 -0400 Subject: [PATCH 112/172] Updated test_prior.py --- test/unit/test_prior.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index 20d71de4..26f89179 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -79,6 +79,7 @@ def test_uniform_sphere(self): log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) + def test_power_law(self): def powerlaw_log_pdf(x, alpha, xmin, xmax): if alpha == -1.0: @@ -96,20 +97,16 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x_base'] - 
base_log_p = jax.vmap(p.log_prob, [0])({'x_base':samples}) - assert jnp.all(jnp.isfinite(base_log_p)) + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_p = jax.vmap(p.log_prob, [0])(samples) + assert jnp.all(jnp.isfinite(log_p)) # Check that the log_prob is correct in the support - samples = jnp.linspace(-10.0, 10.0, 1000) - transformed_samples = jax.vmap(p.transform)({'x_base': samples})['x'] - # cut off the samples that are outside the support - samples = samples[transformed_samples >= xmin] - transformed_samples = transformed_samples[transformed_samples >= xmin] - samples = samples[transformed_samples <= xmax] - transformed_samples = transformed_samples[transformed_samples <= xmax] + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + standard_log_prob = powerlaw_log_pdf(samples['x'], alpha, xmin, xmax) # log pdf of powerlaw - assert jnp.allclose(jax.vmap(p.log_prob)({'x_base':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) + assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4) # Test Pareto Transform func(-1.0) @@ -120,4 +117,4 @@ def func(alpha): func(alpha_val) negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] for alpha_val in negative_alpha: - func(alpha_val) + func(alpha_val) \ No newline at end of file From fe500a792a04a4409874b0564392643c6678033b Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 15:19:06 -0400 Subject: [PATCH 113/172] Use ifos list --- test/integration/test_GW150914.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 6193032a..9adda574 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -28,10 +28,10 @@ fmin = 20.0 fmax = 1024.0 -ifos = ["H1", "L1"] +ifos = [H1, L1] -H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, start_pad, 
end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) q_prior = UniformPrior( @@ -85,7 +85,7 @@ ] likelihood = TransientLikelihoodFD( - [H1, L1], + ifos, waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, From 0d28520440f95697a36fa166841e83cd3e4b18aa Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 17:13:01 -0400 Subject: [PATCH 114/172] Fix jim summary and get_samples --- src/jimgw/jim.py | 64 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 4ccf0451..91a358a6 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -137,7 +137,7 @@ def negative_posterior(x: Float[Array, " n_dim"]): best_fit = optimizer.get_result()[0] return best_fit - def print_summary(self, transform: bool = True): + def print_summary(self): """ Generate summary of the run @@ -146,19 +146,45 @@ def print_summary(self, transform: bool = True): train_summary = self.sampler.get_sampler_state(training=True) production_summary = self.sampler.get_sampler_state(training=False) - training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T - training_chain = self.prior.add_name(training_chain) - if transform: - training_chain = self.prior.transform(training_chain) + training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names)).T + if self.sample_transforms: + transformed_chain = {} + named_sample = self.add_name(training_chain[0]) + for transform in self.sample_transforms: + named_sample = transform.inverse(named_sample) + for key, value in named_sample.items(): + transformed_chain[key] = [value] + for sample in training_chain[1:]: + named_sample = self.add_name(sample) + for transform in self.sample_transforms: + named_sample = transform.inverse(named_sample) + for key, value in 
named_sample.items(): + transformed_chain[key].append(value) + training_chain = transformed_chain + else: + training_chain = self.add_name(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] training_loss = train_summary["loss_vals"] - production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T - production_chain = self.prior.add_name(production_chain) - if transform: - production_chain = self.prior.transform(production_chain) + production_chain = production_summary["chains"].reshape(-1, len(self.parameter_names)).T + if self.sample_transforms: + transformed_chain = {} + named_sample = self.add_name(production_chain[0]) + for transform in self.sample_transforms: + named_sample = transform.inverse(named_sample) + for key, value in named_sample.items(): + transformed_chain[key] = [value] + for sample in production_chain[1:]: + named_sample = self.add_name(sample) + for transform in self.sample_transforms: + named_sample = transform.inverse(named_sample) + for key, value in named_sample.items(): + transformed_chain[key].append(value) + production_chain = transformed_chain + else: + production_chain = self.add_name(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -214,8 +240,24 @@ def get_samples(self, training: bool = False) -> dict: else: chains = self.sampler.get_sampler_state(training=False)["chains"] - chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1))) - return chains + # Need rewrite to output chains instead of flattened samples + chains = chains.reshape(-1, len(self.parameter_names)).T + if self.sample_transforms: + transformed_chain = {} + named_sample = self.add_name(chains[0]) + for transform in self.sample_transforms: + named_sample = 
transform.inverse(named_sample) + for key, value in named_sample.items(): + transformed_chain[key] = [value] + for sample in chains[1:]: + named_sample = self.add_name(sample) + for transform in self.sample_transforms: + named_sample = transform.inverse(named_sample) + for key, value in named_sample.items(): + transformed_chain[key].append(value) + return transformed_chain + else: + return self.add_name(chains) def plot(self): pass From f7e3fe882c44914eda36b12fe6a633d153dd4c0a Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 17:30:15 -0400 Subject: [PATCH 115/172] Fix jim output functions --- src/jimgw/jim.py | 24 ++++++++++++++---------- src/jimgw/transforms.py | 8 ++++---- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 91a358a6..74f65efc 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -151,13 +151,13 @@ def print_summary(self): transformed_chain = {} named_sample = self.add_name(training_chain[0]) for transform in self.sample_transforms: - named_sample = transform.inverse(named_sample) + named_sample = transform.backward(named_sample) for key, value in named_sample.items(): transformed_chain[key] = [value] for sample in training_chain[1:]: named_sample = self.add_name(sample) for transform in self.sample_transforms: - named_sample = transform.inverse(named_sample) + named_sample = transform.backward(named_sample) for key, value in named_sample.items(): transformed_chain[key].append(value) training_chain = transformed_chain @@ -173,13 +173,13 @@ def print_summary(self): transformed_chain = {} named_sample = self.add_name(production_chain[0]) for transform in self.sample_transforms: - named_sample = transform.inverse(named_sample) + named_sample = transform.backward(named_sample) for key, value in named_sample.items(): transformed_chain[key] = [value] for sample in production_chain[1:]: named_sample = self.add_name(sample) for transform in self.sample_transforms: - named_sample = 
transform.inverse(named_sample) + named_sample = transform.backward(named_sample) for key, value in named_sample.items(): transformed_chain[key].append(value) production_chain = transformed_chain @@ -192,7 +192,7 @@ def print_summary(self): print("Training summary") print("=" * 10) for key, value in training_chain.items(): - print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}") + print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}") print( f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}" ) @@ -209,7 +209,7 @@ def print_summary(self): print("Production summary") print("=" * 10) for key, value in production_chain.items(): - print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}") + print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}") print( f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}" ) @@ -246,18 +246,22 @@ def get_samples(self, training: bool = False) -> dict: transformed_chain = {} named_sample = self.add_name(chains[0]) for transform in self.sample_transforms: - named_sample = transform.inverse(named_sample) + named_sample = transform.backward(named_sample) for key, value in named_sample.items(): transformed_chain[key] = [value] for sample in chains[1:]: named_sample = self.add_name(sample) for transform in self.sample_transforms: - named_sample = transform.inverse(named_sample) + named_sample = transform.backward(named_sample) for key, value in named_sample.items(): transformed_chain[key].append(value) - return transformed_chain + output = transformed_chain else: - return self.add_name(chains) + output = self.add_name(chains) + + for key in output.keys(): + output[key] = jnp.array(output[key]) + return output def plot(self): pass diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index c6c4332f..9aea6503 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -115,7 +115,7 @@ class 
BijectiveTransform(NtoNTransform): inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] - def inverse(self, y: dict[str, Float]) -> dict[str, Float]: + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Inverse transform the input y to original coordinate x. @@ -128,6 +128,8 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: ------- x : dict[str, Float] The original dictionary. + log_det : Float + The log Jacobian determinant. """ y_copy = y.copy() transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) @@ -145,7 +147,7 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: ) return y_copy, jacobian - def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + def backward(self, y: dict[str, Float]) -> dict[str, Float]: """ Pull back the input y to original coordinate x and return the log Jacobian determinant. @@ -158,8 +160,6 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: ------- x : dict[str, Float] The original dictionary. - log_det : Float - The log Jacobian determinant. 
""" y_copy = y.copy() output_params = self.inverse_transform_func(y_copy) From 68bef542d34706429ebc1e3a6fd894a9c61e6c88 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 18:14:23 -0400 Subject: [PATCH 116/172] Modify Transform --- src/jimgw/transforms.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 9aea6503..a963b3a7 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -5,7 +5,9 @@ import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped +from astropy.time import Time +from jimgw.single_event.detector import GroundBased2G from jimgw.single_event.utils import ( Mc_q_to_m1_m2, m1_m2_to_Mc_q, @@ -395,20 +397,17 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) + assert "M_c" in name_mapping[0] and "q" in name_mapping[0] and "m_1" in name_mapping[1] and "m_2" in name_mapping[1] def named_transform(x): - Mc = x[name_mapping[0][0]] - q = x[name_mapping[0][1]] - m1, m2 = Mc_q_to_m1_m2(Mc, q) - return {name_mapping[1][0]: m1, name_mapping[1][1]: m2} + m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} self.transform_func = named_transform def named_inverse_transform(x): - m1 = x[name_mapping[1][0]] - m2 = x[name_mapping[1][1]] - Mc, q = m1_m2_to_Mc_q(m1, m2) - return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) + return {"M_c": Mc, "q": q} self.inverse_transform_func = named_inverse_transform @@ -458,32 +457,31 @@ class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], - gmst: Float, - delta_x: Float, + gps_time: Float, + ifos: GroundBased2G, ): super().__init__(name_mapping) - self.gmst = gmst + self.gmst = Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + 
delta_x = ifos[0].vertex - ifos[1].vertex self.rotation = euler_rotation(delta_x) self.rotation_inv = jnp.linalg.inv(self.rotation) + + assert "ra" in name_mapping[0] and "dec" in name_mapping[0] and "zenith" in name_mapping[1] and "azimuth" in name_mapping[1] def named_transform(x): - ra = x[name_mapping[0][0]] - dec = x[name_mapping[0][1]] zenith, azimuth = ra_dec_to_zenith_azimuth( - ra, dec, self.gmst, self.rotation + x["ra"], x["dec"], self.gmst, self.rotation ) - return {name_mapping[1][0]: zenith, name_mapping[1][1]: azimuth} + return {"zenith": zenith, "azimuth": azimuth} self.transform_func = named_transform def named_inverse_transform(x): - zenith = x[name_mapping[1][0]] - azimuth = x[name_mapping[1][1]] ra, dec = zenith_azimuth_to_ra_dec( - zenith, azimuth, self.gmst, self.rotation_inv + x["zenith"], x["azimuth"], self.gmst, self.rotation_inv ) - return {name_mapping[0][0]: ra, name_mapping[0][1]: dec} + return {"ra": ra, "dec": dec} self.inverse_transform_func = named_inverse_transform From 87593e1fe370fb752293ec3e3e34117c12f2ccbb Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 18:59:11 -0400 Subject: [PATCH 117/172] Fix jim output --- src/jimgw/jim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 74f65efc..9b16c424 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -146,7 +146,7 @@ def print_summary(self): train_summary = self.sampler.get_sampler_state(training=True) production_summary = self.sampler.get_sampler_state(training=False) - training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names)).T + training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names)) if self.sample_transforms: transformed_chain = {} named_sample = self.add_name(training_chain[0]) @@ -168,7 +168,7 @@ def print_summary(self): training_global_acceptance = train_summary["global_accs"] training_loss = train_summary["loss_vals"] - production_chain = 
production_summary["chains"].reshape(-1, len(self.parameter_names)).T + production_chain = production_summary["chains"].reshape(-1, len(self.parameter_names)) if self.sample_transforms: transformed_chain = {} named_sample = self.add_name(production_chain[0]) @@ -241,7 +241,7 @@ def get_samples(self, training: bool = False) -> dict: chains = self.sampler.get_sampler_state(training=False)["chains"] # Need rewrite to output chains instead of flattened samples - chains = chains.reshape(-1, len(self.parameter_names)).T + chains = chains.reshape(-1, len(self.parameter_names)) if self.sample_transforms: transformed_chain = {} named_sample = self.add_name(chains[0]) From 730fe31e718a4c498d93dee90c72c27c455c9a6a Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 19:25:51 -0400 Subject: [PATCH 118/172] Add comment --- src/jimgw/jim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 9b16c424..274d5597 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -148,6 +148,7 @@ def print_summary(self): training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names)) if self.sample_transforms: + # Need rewrite to vectorize transformed_chain = {} named_sample = self.add_name(training_chain[0]) for transform in self.sample_transforms: @@ -170,6 +171,7 @@ def print_summary(self): production_chain = production_summary["chains"].reshape(-1, len(self.parameter_names)) if self.sample_transforms: + # Need rewrite to vectorize transformed_chain = {} named_sample = self.add_name(production_chain[0]) for transform in self.sample_transforms: @@ -240,7 +242,7 @@ def get_samples(self, training: bool = False) -> dict: else: chains = self.sampler.get_sampler_state(training=False)["chains"] - # Need rewrite to output chains instead of flattened samples + # Need rewrite to output chains instead of flattened samples and vectorize chains = chains.reshape(-1, len(self.parameter_names)) if self.sample_transforms: 
transformed_chain = {} From ed727afeef73eb1c51b3400c30a2f63eec95464a Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:16:12 -0400 Subject: [PATCH 119/172] Updated test_prior.py --- test/unit/test_prior.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index 26f89179..c68b4dc4 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -30,7 +30,7 @@ def test_uniform(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is correct in the support - samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin))) @@ -40,7 +40,7 @@ def test_sine(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite - samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support @@ -48,7 +48,7 @@ def test_sine(self): y = jax.vmap(p.base_prior.base_prior.transform)(x) y = jax.vmap(p.base_prior.transform)(y) y = jax.vmap(p.transform)(y) - assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.sin(y['x'])/2.0)) + assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.sin(y['x'])/2.0)) def test_cosine(self): p = CosinePrior(["x"]) @@ -56,14 +56,14 @@ def test_cosine(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite - samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = 
jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) y = jax.vmap(p.base_prior.transform)(x) y = jax.vmap(p.transform)(y) - assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.cos(y['x'])/2.0)) + assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.cos(y['x'])/2.0)) def test_uniform_sphere(self): p = UniformSpherePrior(["x"]) From 9d191dacd1ccc3088a8b936006cf38280a832569 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:22:41 -0400 Subject: [PATCH 120/172] Updated test_prior.py --- test/unit/test_prior.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index c68b4dc4..c42d76be 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -30,7 +30,6 @@ def test_uniform(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is correct in the support - samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin))) @@ -40,7 +39,6 @@ def test_sine(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite - samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support @@ -56,7 +54,6 @@ def test_cosine(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite - samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support @@ -73,9 
+70,6 @@ def test_uniform_sphere(self): assert jnp.all(jnp.isfinite(samples['x_theta'])) assert jnp.all(jnp.isfinite(samples['x_phi'])) # Check that the log_prob is finite - samples = {} - for i in range(3): - samples.update(trace_prior_parent(p, [])[i].sample(jax.random.PRNGKey(0), 10000)) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) @@ -97,14 +91,12 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = p.sample(jax.random.PRNGKey(0), 10000) - log_p = jax.vmap(p.log_prob, [0])(samples) + log_p = jax.vmap(p.log_prob, [0])(powerlaw_samples) assert jnp.all(jnp.isfinite(log_p)) # Check that the log_prob is correct in the support - samples = p.sample(jax.random.PRNGKey(0), 10000) - log_prob = jax.vmap(p.log_prob)(samples) - standard_log_prob = powerlaw_log_pdf(samples['x'], alpha, xmin, xmax) + log_prob = jax.vmap(p.log_prob)(powerlaw_samples) + standard_log_prob = powerlaw_log_pdf(powerlaw_samples['x'], alpha, xmin, xmax) # log pdf of powerlaw assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4) From 5a5ff2f345ea147f9995b8f5986cc1c6b03996ea Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 10:53:50 -0400 Subject: [PATCH 121/172] Add sky position transform --- src/jimgw/jim.py | 2 +- test/integration/test_GW150914.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 274d5597..2cc463b2 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -94,7 +94,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) transform_jacobian = 0.0 - for transform in self.sample_transforms: + for transform in reversed(self.sample_transforms): named_params, jacobian = transform.inverse(named_params) transform_jacobian += jacobian prior = self.prior.log_prob(named_params) + 
transform_jacobian diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 9adda574..f29b8ad6 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,7 +8,7 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD -from jimgw.transforms import BoundToUnbound, MassRatioToSymmetricMassRatioTransform +from jimgw.transforms import BoundToUnbound, MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -76,8 +76,9 @@ BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ From bd915efaaaef6e6e7e10a4af39908f47d1d6c50b Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 2 Aug 2024 11:03:57 -0400 Subject: [PATCH 122/172] Add powerlaw transform back --- src/jimgw/transforms.py | 45 
+++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 1f9e1185..033e4882 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -485,6 +485,51 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +class PowerLawTransform(BijectiveTransform): + """ + PowerLaw transformation + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + xmin: Float + xmax: Float + alpha: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + alpha: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.alpha = alpha + self.transform_func = lambda x: { + name_mapping[1][i]: ( + self.xmin ** (1.0 + self.alpha) + + x[name_mapping[0][i]] + * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + ** (1.0 / (1.0 + self.alpha)) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: ( + ( + x[name_mapping[1][i]] ** (1.0 + self.alpha) + - self.xmin ** (1.0 + self.alpha) + ) + / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + for i in range(len(name_mapping[1])) + } + + class ParetoTransform(BijectiveTransform): """ Pareto transformation: Power law when alpha = -1 From ba86a65701994a2cbbc773888e41b8bf305c9768 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 11:47:12 -0400 Subject: [PATCH 123/172] Move single_event prior and transform --- src/jimgw/jim.py | 16 +- src/jimgw/prior.py | 219 --------------------------- src/jimgw/single_event/prior.py | 198 ++++++++++++++++++++++++ src/jimgw/single_event/transforms.py | 124 +++++++++++++++ src/jimgw/single_event/utils.py | 12 +- src/jimgw/transforms.py | 120 +-------------- 6 files changed, 341 insertions(+), 348 deletions(-) create 
mode 100644 src/jimgw/single_event/prior.py create mode 100644 src/jimgw/single_event/transforms.py diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 2cc463b2..7d86fecf 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -100,9 +100,7 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): prior = self.prior.log_prob(named_params) + transform_jacobian for transform in self.likelihood_transforms: named_params = transform.forward(named_params) - return ( - self.likelihood.evaluate(named_params, data) + prior - ) + return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: @@ -169,7 +167,9 @@ def print_summary(self): training_global_acceptance = train_summary["global_accs"] training_loss = train_summary["loss_vals"] - production_chain = production_summary["chains"].reshape(-1, len(self.parameter_names)) + production_chain = production_summary["chains"].reshape( + -1, len(self.parameter_names) + ) if self.sample_transforms: # Need rewrite to vectorize transformed_chain = {} @@ -194,7 +194,9 @@ def print_summary(self): print("Training summary") print("=" * 10) for key, value in training_chain.items(): - print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}") + print( + f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}" + ) print( f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}" ) @@ -211,7 +213,9 @@ def print_summary(self): print("Production summary") print("=" * 10) for key, value in production_chain.items(): - print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}") + print( + f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}" + ) print( f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}" ) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e4415618..5227ffa5 100644 --- 
a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -398,227 +398,8 @@ def __init__( ) -@jaxtyped(typechecker=typechecker) -class UniformComponentChirpMassPrior(PowerLawPrior): - """ - A prior in the range [xmin, xmax) for chirp mass which assumes the - component masses to be uniformly distributed. - - p(M_c) ~ M_c - """ - - def __repr__(self): - return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" - - def __init__(self, xmin: float, xmax: float): - super().__init__(xmin, xmax, 1.0, ["M_c"]) - - -def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: - if prior.composite: - if isinstance(prior.base_prior, list): - for subprior in prior.base_prior: - output = trace_prior_parent(subprior, output) - elif isinstance(prior.base_prior, Prior): - output = trace_prior_parent(prior.base_prior, output) - else: - output.append(prior) - - return output - - # ====================== Things below may need rework ====================== - -# @jaxtyped(typechecker=typechecker) -# class AlignedSpin(Prior): -# """ -# Prior distribution for the aligned (z) component of the spin. - -# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] -# with its orientation uniform on a sphere - -# p(chi) = -log(|chi| / amax) / 2 / amax - -# This is useful when comparing results between an aligned-spin run and -# a precessing spin run. - -# See (A7) of https://arxiv.org/abs/1805.10457. 
-# """ - -# amax: Float = 0.99 -# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) -# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - -# def __repr__(self): -# return f"Alignedspin(amax={self.amax}, naming={self.naming})" - -# def __init__( -# self, -# amax: Float, -# naming: list[str], -# transforms: dict[str, tuple[str, Callable]] = {}, -# **kwargs, -# ): -# super().__init__(naming, transforms) -# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" -# self.amax = amax - -# # build the interpolation table for the ppf of the one-sided distribution -# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) -# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax -# self.chi_axis = chi_axis -# self.cdf_vals = cdf_vals - -# @property -# def xmin(self): -# return -self.amax - -# @property -# def xmax(self): -# return self.amax - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# """ -# Sample from the Alignedspin distribution. - -# for chi > 0; -# p(chi) = -log(chi / amax) / amax # halved normalization constant -# cdf(chi) = -chi * (log(chi / amax) - 1) / amax - -# Since there is a pole at chi=0, we will sample with the following steps -# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise -# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) -# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) -# 3. Map the quantile to chi via the ppf by checking against the table -# built during the initialization -# 4. add back the sign - -# Parameters -# ---------- -# rng_key : PRNGKeyArray -# A random key to use for sampling. -# n_samples : int -# The number of samples to draw. - -# Returns -# ------- -# samples : dict -# Samples from the distribution. The keys are the names of the parameters. 
- -# """ -# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) -# # 1. calculate the sign of chi from the q_samples -# sign_samples = jnp.where( -# q_samples >= 0.5, -# jnp.zeros_like(q_samples) + 1.0, -# jnp.zeros_like(q_samples) - 1.0, -# ) -# # 2. remap q_samples -# q_samples = jnp.where( -# q_samples >= 0.5, -# 2 * (q_samples - 0.5), -# 2 * (0.5 - q_samples), -# ) -# # 3. map the quantile to chi via interpolation -# samples = jnp.interp( -# q_samples, -# self.cdf_vals, -# self.chi_axis, -# ) -# # 4. add back the sign -# samples *= sign_samples - -# return self.add_name(samples[None]) - -# def log_prob(self, x: dict[str, Float]) -> Float: -# variable = x[self.naming[0]] -# log_p = jnp.where( -# (variable >= self.amax) | (variable <= -self.amax), -# jnp.zeros_like(variable) - jnp.inf, -# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), -# ) -# return log_p - - -# @jaxtyped(typechecker=typechecker) -# class EarthFrame(Prior): -# """ -# Prior distribution for sky location in Earth frame. 
-# """ - -# ifos: list = field(default_factory=list) -# gmst: float = 0.0 -# delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) - -# def __repr__(self): -# return f"EarthFrame(naming={self.naming})" - -# def __init__(self, gps: Float, ifos: list, **kwargs): -# self.naming = ["zenith", "azimuth"] -# if len(ifos) < 2: -# return ValueError( -# "At least two detectors are needed to define the Earth frame" -# ) -# elif isinstance(ifos[0], str): -# self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] -# elif isinstance(ifos[0], GroundBased2G): -# self.ifos = ifos[:1] -# else: -# return ValueError( -# "ifos should be a list of detector names or GroundBased2G objects" -# ) -# self.gmst = float( -# Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad -# ) -# self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex - -# self.transforms = { -# "azimuth": ( -# "ra", -# lambda params: zenith_azimuth_to_ra_dec( -# params["zenith"], -# params["azimuth"], -# gmst=self.gmst, -# delta_x=self.delta_x, -# )[0], -# ), -# "zenith": ( -# "dec", -# lambda params: zenith_azimuth_to_ra_dec( -# params["zenith"], -# params["azimuth"], -# gmst=self.gmst, -# delta_x=self.delta_x, -# )[1], -# ), -# } - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# rng_keys = jax.random.split(rng_key, 2) -# zenith = jnp.arccos( -# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) -# ) -# azimuth = jax.random.uniform( -# rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi -# ) -# return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) - -# def log_prob(self, x: dict[str, Float]) -> Float: -# zenith = x["zenith"] -# azimuth = x["azimuth"] -# output = jnp.where( -# (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), -# jnp.zeros_like(0) - jnp.inf, -# jnp.zeros_like(0), -# ) -# return output + jnp.log(jnp.sin(zenith)) - - # 
@jaxtyped(typechecker=typechecker) # class Exponential(Prior): # """ diff --git a/src/jimgw/single_event/prior.py b/src/jimgw/single_event/prior.py new file mode 100644 index 00000000..e0173482 --- /dev/null +++ b/src/jimgw/single_event/prior.py @@ -0,0 +1,198 @@ +from dataclasses import field + +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from flowMC.nfmodel.base import Distribution +from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped + +from jimgw.prior import Prior, CombinePrior, UniformPrior, PowerLawPrior, SinePrior, CosinePrior + +@jaxtyped(typechecker=typechecker) +class UniformSpherePrior(CombinePrior): + + def __repr__(self): + return f"UniformSpherePrior(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str], **kwargs): + self.parameter_names = parameter_names + assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector" + self.parameter_names = [ + f"{self.parameter_names[0]}_mag", + f"{self.parameter_names[0]}_theta", + f"{self.parameter_names[0]}_phi", + ] + super().__init__( + [ + UniformPrior(0.0, 1.0, [self.parameter_names[0]]), + SinePrior([self.parameter_names[1]]), + UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + ] + ) + +@jaxtyped(typechecker=typechecker) +class UniformComponentMassPrior(CombinePrior): + """ + A prior in the range [xmin, xmax) for component masses which assumes the + component masses to be uniformly distributed. + """ + + def __repr__(self): + return f"UniformComponentMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + + def __init__(self, xmin: float, xmax: float): + self.parameter_names = ["m_1", "m_2"] + super().__init__( + [ + UniformPrior(xmin, xmax, ["m_1"]), + UniformPrior(xmin, xmax, ["m_2"]), + ] + ) + + def log_prob(self, z: dict[str, Float]) -> Float: + output = super().log_prob(z) + output += jnp.log(2.) 
+ +@jaxtyped(typechecker=typechecker) +class UniformComponentChirpMassPrior(PowerLawPrior): + """ + A prior in the range [xmin, xmax) for chirp mass which assumes the + component masses to be uniformly distributed. + + p(M_c) ~ M_c + """ + + def __repr__(self): + return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + + def __init__(self, xmin: float, xmax: float): + super().__init__(xmin, xmax, 1.0, ["M_c"]) + + +def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output + + +# ====================== Things below may need rework ====================== + + +# @jaxtyped(typechecker=typechecker) +# class AlignedSpin(Prior): +# """ +# Prior distribution for the aligned (z) component of the spin. + +# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] +# with its orientation uniform on a sphere + +# p(chi) = -log(|chi| / amax) / 2 / amax + +# This is useful when comparing results between an aligned-spin run and +# a precessing spin run. + +# See (A7) of https://arxiv.org/abs/1805.10457. 
+# """ + +# amax: Float = 0.99 +# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) +# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + +# def __repr__(self): +# return f"Alignedspin(amax={self.amax}, naming={self.naming})" + +# def __init__( +# self, +# amax: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" +# self.amax = amax + +# # build the interpolation table for the ppf of the one-sided distribution +# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) +# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax +# self.chi_axis = chi_axis +# self.cdf_vals = cdf_vals + +# @property +# def xmin(self): +# return -self.amax + +# @property +# def xmax(self): +# return self.amax + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from the Alignedspin distribution. + +# for chi > 0; +# p(chi) = -log(chi / amax) / amax # halved normalization constant +# cdf(chi) = -chi * (log(chi / amax) - 1) / amax + +# Since there is a pole at chi=0, we will sample with the following steps +# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise +# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) +# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) +# 3. Map the quantile to chi via the ppf by checking against the table +# built during the initialization +# 4. add back the sign + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. 
+ +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# # 1. calculate the sign of chi from the q_samples +# sign_samples = jnp.where( +# q_samples >= 0.5, +# jnp.zeros_like(q_samples) + 1.0, +# jnp.zeros_like(q_samples) - 1.0, +# ) +# # 2. remap q_samples +# q_samples = jnp.where( +# q_samples >= 0.5, +# 2 * (q_samples - 0.5), +# 2 * (0.5 - q_samples), +# ) +# # 3. map the quantile to chi via interpolation +# samples = jnp.interp( +# q_samples, +# self.cdf_vals, +# self.chi_axis, +# ) +# # 4. add back the sign +# samples *= sign_samples + +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_p = jnp.where( +# (variable >= self.amax) | (variable <= -self.amax), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), +# ) +# return log_p diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py new file mode 100644 index 00000000..b6070510 --- /dev/null +++ b/src/jimgw/single_event/transforms.py @@ -0,0 +1,124 @@ +from abc import ABC +from typing import Callable + +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped +from astropy.time import Time + +from jimgw.single_event.detector import GroundBased2G +from jimgw.transforms import BijectiveTransform +from jimgw.single_event.utils import ( + Mc_q_to_m1_m2, + m1_m2_to_Mc_q, + q_to_eta, + eta_to_q, + ra_dec_to_zenith_azimuth, + zenith_azimuth_to_ra_dec, + euler_rotation, +) + +@jaxtyped(typechecker=typechecker) +class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to component masses + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + assert "m_1" in name_mapping[0] and "m_2" in name_mapping[0] and "M_c" in name_mapping[1] and "q" in name_mapping[1] + + def named_transform(x): + Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) + return {"M_c": Mc, "q": q} + + self.transform_func = named_transform + + def named_inverse_transform(x): + m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} + + self.inverse_transform_func = named_inverse_transform + + +@jaxtyped(typechecker=typechecker) +class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): + """ + Transform mass ratio to symmetric mass ratio + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + self.transform_func = lambda x: { + name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]]) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]]) + } + + +@jaxtyped(typechecker=typechecker) +class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): + """ + Transform sky frame to detector frame sky position + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + gmst: Float + rotation: Float[Array, " 3 3"] + rotation_inv: Float[Array, " 3 3"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifos: GroundBased2G, + ): + super().__init__(name_mapping) + + self.gmst = Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + delta_x = ifos[0].vertex - ifos[1].vertex + self.rotation = euler_rotation(delta_x) + self.rotation_inv = jnp.linalg.inv(self.rotation) + + assert "ra" in name_mapping[0] and "dec" in name_mapping[0] and "zenith" in name_mapping[1] and "azimuth" in name_mapping[1] + + def named_transform(x): + zenith, azimuth = ra_dec_to_zenith_azimuth( + x["ra"], x["dec"], self.gmst, self.rotation + ) + return {"zenith": zenith, "azimuth": azimuth} + + self.transform_func = named_transform + + def named_inverse_transform(x): + ra, dec = zenith_azimuth_to_ra_dec( + x["zenith"], x["azimuth"], self.gmst, self.rotation_inv + ) + return {"ra": ra, "dec": dec} + + self.inverse_transform_func = named_inverse_transform diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 8721c176..0d1313ca 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -259,8 +259,8 @@ def angle_rotation( Zenith angle. azimuth : Float Azimuthal angle. - delta_x : Float - The vector pointing from the first detector to the second detector. + rotation : Float[Array, " 3 3"] + The rotation matrix. Returns ------- @@ -334,8 +334,8 @@ def zenith_azimuth_to_ra_dec( Azimuthal angle. gmst : Float Greenwich mean sidereal time. - delta_x : Float - The vector pointing from the first detector to the second detector. + rotation : Float[Array, " 3 3"] + The rotation matrix. Copied and modified from bilby/gw/utils.py @@ -392,8 +392,8 @@ def ra_dec_to_zenith_azimuth( Declination. gmst : Float Greenwich mean sidereal time. - delta_x : Float - The vector pointing from the first detector to the second detector. 
+ rotation : Float[Array, " 3 3"] + The rotation matrix. Returns ------- diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 033e4882..4c67a076 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -5,18 +5,7 @@ import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped -from astropy.time import Time -from jimgw.single_event.detector import GroundBased2G -from jimgw.single_event.utils import ( - Mc_q_to_m1_m2, - m1_m2_to_Mc_q, - q_to_eta, - eta_to_q, - ra_dec_to_zenith_azimuth, - zenith_azimuth_to_ra_dec, - euler_rotation, -) class Transform(ABC): @@ -346,7 +335,9 @@ def logit(x): for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: (self.original_upper_bound[i] - self.original_lower_bound[i]) + name_mapping[0][i]: ( + self.original_upper_bound[i] - self.original_lower_bound[i] + ) / (1 + jnp.exp(-x[name_mapping[1][i]])) + self.original_lower_bound[i] for i in range(len(name_mapping[1])) @@ -380,111 +371,6 @@ def __init__( } -@jaxtyped(typechecker=typechecker) -class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): - """ - Transform chirp mass and mass ratio to component masses - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. 
- """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - assert "M_c" in name_mapping[0] and "q" in name_mapping[0] and "m_1" in name_mapping[1] and "m_2" in name_mapping[1] - - def named_transform(x): - m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) - return {"m_1": m1, "m_2": m2} - - self.transform_func = named_transform - - def named_inverse_transform(x): - Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) - return {"M_c": Mc, "q": q} - - self.inverse_transform_func = named_inverse_transform - - -@jaxtyped(typechecker=typechecker) -class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): - """ - Transform mass ratio to symmetric mass ratio - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - - self.transform_func = lambda x: { - name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]]) - } - self.inverse_transform_func = lambda x: { - name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]]) - } - - -@jaxtyped(typechecker=typechecker) -class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): - """ - Transform sky frame to detector frame sky position - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. 
- - """ - - gmst: Float - rotation: Float[Array, " 3 3"] - rotation_inv: Float[Array, " 3 3"] - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - gps_time: Float, - ifos: GroundBased2G, - ): - super().__init__(name_mapping) - - self.gmst = Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad - delta_x = ifos[0].vertex - ifos[1].vertex - self.rotation = euler_rotation(delta_x) - self.rotation_inv = jnp.linalg.inv(self.rotation) - - assert "ra" in name_mapping[0] and "dec" in name_mapping[0] and "zenith" in name_mapping[1] and "azimuth" in name_mapping[1] - - def named_transform(x): - zenith, azimuth = ra_dec_to_zenith_azimuth( - x["ra"], x["dec"], self.gmst, self.rotation - ) - return {"zenith": zenith, "azimuth": azimuth} - - self.transform_func = named_transform - - def named_inverse_transform(x): - ra, dec = zenith_azimuth_to_ra_dec( - x["zenith"], x["azimuth"], self.gmst, self.rotation_inv - ) - return {"ra": ra, "dec": dec} - - self.inverse_transform_func = named_inverse_transform - - class PowerLawTransform(BijectiveTransform): """ PowerLaw transformation From 3e1ea7172d6d9487298c134bd5f567607656c7b6 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 12:27:42 -0400 Subject: [PATCH 124/172] Tidy up test --- src/jimgw/single_event/prior.py | 37 +++------ src/jimgw/single_event/transforms.py | 75 ++++++++++++++---- src/jimgw/single_event/utils.py | 112 ++++++++++++++++++++++----- src/jimgw/transforms.py | 1 - test/integration/test_GW150914.py | 13 ++-- 5 files changed, 166 insertions(+), 72 deletions(-) diff --git a/src/jimgw/single_event/prior.py b/src/jimgw/single_event/prior.py index e0173482..51a754eb 100644 --- a/src/jimgw/single_event/prior.py +++ b/src/jimgw/single_event/prior.py @@ -1,12 +1,15 @@ -from dataclasses import field - -import jax import jax.numpy as jnp from beartype import beartype as typechecker -from flowMC.nfmodel.base import Distribution -from jaxtyping import Array, Float, PRNGKeyArray, 
jaxtyped +from jaxtyping import jaxtyped + +from jimgw.prior import ( + Prior, + CombinePrior, + UniformPrior, + PowerLawPrior, + SinePrior, +) -from jimgw.prior import Prior, CombinePrior, UniformPrior, PowerLawPrior, SinePrior, CosinePrior @jaxtyped(typechecker=typechecker) class UniformSpherePrior(CombinePrior): @@ -30,28 +33,6 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) -@jaxtyped(typechecker=typechecker) -class UniformComponentMassPrior(CombinePrior): - """ - A prior in the range [xmin, xmax) for component masses which assumes the - component masses to be uniformly distributed. - """ - - def __repr__(self): - return f"UniformComponentMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" - - def __init__(self, xmin: float, xmax: float): - self.parameter_names = ["m_1", "m_2"] - super().__init__( - [ - UniformPrior(xmin, xmax, ["m_1"]), - UniformPrior(xmin, xmax, ["m_2"]), - ] - ) - - def log_prob(self, z: dict[str, Float]) -> Float: - output = super().log_prob(z) - output += jnp.log(2.) 
@jaxtyped(typechecker=typechecker) class UniformComponentChirpMassPrior(PowerLawPrior): diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index b6070510..fe12c52a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -1,7 +1,3 @@ -from abc import ABC -from typing import Callable - -import jax import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped @@ -10,8 +6,10 @@ from jimgw.single_event.detector import GroundBased2G from jimgw.transforms import BijectiveTransform from jimgw.single_event.utils import ( - Mc_q_to_m1_m2, m1_m2_to_Mc_q, + Mc_q_to_m1_m2, + m1_m2_to_Mc_eta, + Mc_eta_to_m1_m2, q_to_eta, eta_to_q, ra_dec_to_zenith_azimuth, @@ -19,6 +17,7 @@ euler_rotation, ) + @jaxtyped(typechecker=typechecker) class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): """ @@ -35,7 +34,12 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - assert "m_1" in name_mapping[0] and "m_2" in name_mapping[0] and "M_c" in name_mapping[1] and "q" in name_mapping[1] + assert ( + "m_1" in name_mapping[0] + and "m_2" in name_mapping[0] + and "M_c" in name_mapping[1] + and "q" in name_mapping[1] + ) def named_transform(x): Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) @@ -50,6 +54,43 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform): + """ + Transform mass ratio to symmetric mass ratio + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + assert ( + "m_1" in name_mapping[0] + and "m_2" in name_mapping[0] + and "M_c" in name_mapping[1] + and "eta" in name_mapping[1] + ) + + def named_transform(x): + Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) + return {"M_c": Mc, "eta": eta} + + self.transform_func = named_transform + + def named_inverse_transform(x): + m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): """ @@ -67,13 +108,10 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) + assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0] - self.transform_func = lambda x: { - name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]]) - } - self.inverse_transform_func = lambda x: { - name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]]) - } + self.transform_func = lambda x: {"eta": q_to_eta(x["q"])} + self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])} @jaxtyped(typechecker=typechecker) @@ -100,12 +138,19 @@ def __init__( ): super().__init__(name_mapping) - self.gmst = Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) delta_x = ifos[0].vertex - ifos[1].vertex self.rotation = euler_rotation(delta_x) self.rotation_inv = jnp.linalg.inv(self.rotation) - - assert "ra" in name_mapping[0] and "dec" in name_mapping[0] and "zenith" in name_mapping[1] and "azimuth" in name_mapping[1] + + assert ( + "ra" in name_mapping[0] + and "dec" in name_mapping[0] + and "zenith" in name_mapping[1] + and "azimuth" in name_mapping[1] + ) def named_transform(x): zenith, azimuth = ra_dec_to_zenith_azimuth( diff --git a/src/jimgw/single_event/utils.py 
b/src/jimgw/single_event/utils.py index 0d1313ca..16d6cc7c 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -56,12 +56,12 @@ def m1_m2_to_M_q(m1: Float, m2: Float): q : Float Mass ratio. """ - M_tot = jnp.log(m1 + m2) - q = jnp.log(m2 / m1) - jnp.log(1 - m2 / m1) + M_tot = m1 + m2 + q = m2 / m1 return M_tot, q -def M_q_to_m1_m2(trans_M_tot: Float, trans_q: Float): +def M_q_to_m1_m2(M_tot: Float, q: Float): """ Transforming the Total mass M and mass ratio q to the primary mass m1 and secondary mass m2. @@ -80,13 +80,37 @@ def M_q_to_m1_m2(trans_M_tot: Float, trans_q: Float): m2 : Float Secondary mass. """ - M_tot = jnp.exp(trans_M_tot) - q = 1.0 / (1 + jnp.exp(-trans_q)) m1 = M_tot / (1 + q) m2 = m1 * q return m1, m2 +def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c + and mass ratio q. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + """ + M_tot = m1 + m2 + eta = m1 * m2 / M_tot**2 + M_c = M_tot * eta ** (3.0 / 5) + q = m2 / m1 + return M_c, q + + def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]: """ Transforming the chirp mass M_c and mass ratio q to the primary mass m1 and @@ -113,10 +137,10 @@ def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]: return m1, m2 -def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: +def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: """ - Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c - and mass ratio q. + Transforming the primary mass m1 and secondary mass m2 to the total mass M + and symmetric mass ratio eta. Parameters ---------- @@ -127,21 +151,43 @@ def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: Returns ------- - M_c : Float - Chirp mass. - q : Float - Mass ratio. 
+ M : Float + Total mass. + eta : Float + Symmetric mass ratio. """ M_tot = m1 + m2 eta = m1 * m2 / M_tot**2 - M_c = M_tot * eta ** (3.0 / 5) - q = m2 / m1 - return M_c, q + return M_tot, eta -def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: +def M_eta_to_m1_m2(M_tot: Float, eta: Float) -> tuple[Float, Float]: """ - Transforming the primary mass m1 and secondary mass m2 to the total mass M + Transforming the total mass M and symmetric mass ratio eta to the primary mass m1 + and secondary mass m2. + + Parameters + ---------- + M : Float + Total mass. + eta : Float + Symmetric mass ratio. + + Returns + ------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + """ + m1 = M_tot * (1 + jnp.sqrt(1 - 4 * eta)) / 2 + m2 = M_tot * (1 - jnp.sqrt(1 - 4 * eta)) / 2 + return m1, m2 + + +def m1_m2_to_Mc_eta(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c and symmetric mass ratio eta. Parameters @@ -153,14 +199,40 @@ def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: Returns ------- - M : Float - Total mass. + M_c : Float + Chirp mass. eta : Float Symmetric mass ratio. """ M = m1 + m2 eta = m1 * m2 / M**2 - return M, eta + M_c = M * eta ** (3.0 / 5) + return M_c, eta + + +def Mc_eta_to_m1_m2(M_c: Float, eta: Float) -> tuple[Float, Float]: + """ + Transforming the chirp mass M_c and symmetric mass ratio eta to the primary mass m1 + and secondary mass m2. + + Parameters + ---------- + M_c : Float + Chirp mass. + eta : Float + Symmetric mass ratio. + + Returns + ------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. 
+ """ + M = M_c / eta ** (3.0 / 5) + m1 = M * (1 + jnp.sqrt(1 - 4 * eta)) / 2 + m2 = M * (1 - jnp.sqrt(1 - 4 * eta)) / 2 + return m1, m2 def q_to_eta(q: Float) -> Float: diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4c67a076..715d49de 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -7,7 +7,6 @@ from jaxtyping import Float, Array, jaxtyped - class Transform(ABC): """ Base class for transform. diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index f29b8ad6..6fddf9ea 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,7 +8,9 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD -from jimgw.transforms import BoundToUnbound, MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -34,14 +36,9 @@ ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior( - 0.125, - 1., - parameter_names=["q"], -) +q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) -# Current likelihood sampling will fail and give nan because of large number dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) 
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) @@ -77,7 +74,7 @@ BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), - BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] From 6ad882a40df8358a00b63ff794467fca46a25396 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 12:38:27 -0400 Subject: [PATCH 125/172] Add utils.py --- src/jimgw/utils.py | 13 +++++++++++++ test/unit/test_prior.py | 1 + 2 files changed, 14 insertions(+) create mode 100644 src/jimgw/utils.py diff --git a/src/jimgw/utils.py b/src/jimgw/utils.py new file mode 100644 index 00000000..e629903a --- /dev/null +++ b/src/jimgw/utils.py @@ -0,0 +1,13 @@ +from jimgw.prior import Prior + +def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index c42d76be..852ded16 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -1,4 +1,5 @@ from jimgw.prior import * +from jimgw.utils import trace_prior_parent import scipy.stats as stats From 
ede2b99faa70219201b17147477041b2fe9bb71f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 12:39:40 -0400 Subject: [PATCH 126/172] Move log_i0 --- src/jimgw/single_event/utils.py | 20 -------------------- src/jimgw/utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 16d6cc7c..b2fb8c25 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -1,6 +1,5 @@ import jax.numpy as jnp from jax.scipy.integrate import trapezoid -from jax.scipy.special import i0e from jaxtyping import Array, Float @@ -477,22 +476,3 @@ def ra_dec_to_zenith_azimuth( theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) zenith, azimuth = angle_rotation(theta, phi, rotation) return zenith, azimuth - - -def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]: - """ - A numerically stable method to evaluate log of - a modified Bessel function of order 0. - It is used in the phase-marginalized likelihood. - - Parameters - ========== - x: array-like - Value(s) at which to evaluate the function - - Returns - ======= - array-like: - The natural logarithm of the bessel function - """ - return jnp.log(i0e(x)) + x diff --git a/src/jimgw/utils.py b/src/jimgw/utils.py index e629903a..70c6e166 100644 --- a/src/jimgw/utils.py +++ b/src/jimgw/utils.py @@ -1,5 +1,10 @@ +import jax.numpy as jnp +from jax.scipy.special import i0e +from jaxtyping import Array, Float + from jimgw.prior import Prior + def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: if prior.composite: if isinstance(prior.base_prior, list): @@ -11,3 +16,22 @@ def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: output.append(prior) return output + + +def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]: + """ + A numerically stable method to evaluate log of + a modified Bessel function of order 0. 
+ It is used in the phase-marginalized likelihood. + + Parameters + ========== + x: array-like + Value(s) at which to evaluate the function + + Returns + ======= + array-like: + The natural logarithm of the bessel function + """ + return jnp.log(i0e(x)) + x From e76469611477fd80b3e9f16d19aba21f25132bc7 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 12:43:12 -0400 Subject: [PATCH 127/172] Fixing check --- src/jimgw/single_event/likelihood.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index f10aeed1..ce2e8f0e 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -12,7 +12,7 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior from jimgw.single_event.detector import Detector -from jimgw.single_event.utils import log_i0 +from jimgw.utils import log_i0 from jimgw.single_event.waveform import Waveform From 093dd090944062e64db88eef788d0a7a4cf006ba Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:51:20 -0400 Subject: [PATCH 128/172] Updated jim.py --- src/jimgw/jim.py | 87 +++++++++++------------------------------------- 1 file changed, 19 insertions(+), 68 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 7d86fecf..a31e614b 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -135,7 +135,7 @@ def negative_posterior(x: Float[Array, " n_dim"]): best_fit = optimizer.get_result()[0] return best_fit - def print_summary(self): + def print_summary(self, transform: bool = True): """ Generate summary of the run @@ -144,49 +144,21 @@ def print_summary(self): train_summary = self.sampler.get_sampler_state(training=True) production_summary = self.sampler.get_sampler_state(training=False) - training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names)) - if self.sample_transforms: - # Need rewrite to vectorize - 
transformed_chain = {} - named_sample = self.add_name(training_chain[0]) - for transform in self.sample_transforms: - named_sample = transform.backward(named_sample) - for key, value in named_sample.items(): - transformed_chain[key] = [value] - for sample in training_chain[1:]: - named_sample = self.add_name(sample) - for transform in self.sample_transforms: - named_sample = transform.backward(named_sample) - for key, value in named_sample.items(): - transformed_chain[key].append(value) - training_chain = transformed_chain - else: - training_chain = self.add_name(training_chain) + training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T + training_chain = self.add_name(training_chain) + if transform: + for sample_transform in self.sample_transforms: + training_chain = sample_transform.backward(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] training_loss = train_summary["loss_vals"] - production_chain = production_summary["chains"].reshape( - -1, len(self.parameter_names) - ) - if self.sample_transforms: - # Need rewrite to vectorize - transformed_chain = {} - named_sample = self.add_name(production_chain[0]) - for transform in self.sample_transforms: - named_sample = transform.backward(named_sample) - for key, value in named_sample.items(): - transformed_chain[key] = [value] - for sample in production_chain[1:]: - named_sample = self.add_name(sample) - for transform in self.sample_transforms: - named_sample = transform.backward(named_sample) - for key, value in named_sample.items(): - transformed_chain[key].append(value) - production_chain = transformed_chain - else: - production_chain = self.add_name(production_chain) + production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T + production_chain = self.add_name(production_chain) + if transform: + for sample_transform in self.sample_transforms: + 
production_chain = sample_transform.backward(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -194,9 +166,7 @@ def print_summary(self): print("Training summary") print("=" * 10) for key, value in training_chain.items(): - print( - f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}" - ) + print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}") print( f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}" ) @@ -213,9 +183,7 @@ def print_summary(self): print("Production summary") print("=" * 10) for key, value in production_chain.items(): - print( - f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}" - ) + print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}") print( f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}" ) @@ -242,32 +210,15 @@ def get_samples(self, training: bool = False) -> dict: """ if training: - chains = self.sampler.get_sampler_state(training=True)["chains"] + chains = self.sampler.get_sampler_state(training=True)["chains"] # (500, 10100, 15) else: chains = self.sampler.get_sampler_state(training=False)["chains"] - # Need rewrite to output chains instead of flattened samples and vectorize - chains = chains.reshape(-1, len(self.parameter_names)) - if self.sample_transforms: - transformed_chain = {} - named_sample = self.add_name(chains[0]) - for transform in self.sample_transforms: - named_sample = transform.backward(named_sample) - for key, value in named_sample.items(): - transformed_chain[key] = [value] - for sample in chains[1:]: - named_sample = self.add_name(sample) - for transform in self.sample_transforms: - named_sample = transform.backward(named_sample) - for key, value in named_sample.items(): - transformed_chain[key].append(value) - output = transformed_chain - else: - output 
= self.add_name(chains) - - for key in output.keys(): - output[key] = jnp.array(output[key]) - return output + chains = chains.transpose(2, 0, 1) + chains = self.add_name(chains) + for sample_transform in self.sample_transforms: + chains = sample_transform.backward(chains) + return chains def plot(self): pass From ceb0b7f743c93a2bf011ccbb2b385a48d2b20d15 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:12:57 -0400 Subject: [PATCH 129/172] Added spin transform --- src/jimgw/single_event/transforms.py | 54 ++++++++- src/jimgw/single_event/utils.py | 158 ++++++++++++++++++++++++++ test/integration/test_GW150914_pv2.py | 141 +++++++++++++++++++++++ 3 files changed, 352 insertions(+), 1 deletion(-) create mode 100644 test/integration/test_GW150914_pv2.py diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index fe12c52a..2f7805e4 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -4,7 +4,7 @@ from astropy.time import Time from jimgw.single_event.detector import GroundBased2G -from jimgw.transforms import BijectiveTransform +from jimgw.transforms import BijectiveTransform, NtoNTransform from jimgw.single_event.utils import ( m1_m2_to_Mc_q, Mc_q_to_m1_m2, @@ -15,6 +15,7 @@ ra_dec_to_zenith_azimuth, zenith_azimuth_to_ra_dec, euler_rotation, + spin_to_cartesian_spin, ) @@ -167,3 +168,54 @@ def named_inverse_transform(x): return {"ra": ra, "dec": dec} self.inverse_transform_func = named_inverse_transform + + +@jaxtyped(typechecker=typechecker) +class SpinToCartesianSpinTransform(NtoNTransform): + """ + Spin to Cartesian spin transformation + """ + + freq_ref: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + freq_ref: Float, + ): + super().__init__(name_mapping) + + self.freq_ref = freq_ref + + assert ( + "theta_jn" in name_mapping[0] + and "phi_jl" in name_mapping[0] + and "theta1" in name_mapping[0] + and 
"theta2" in name_mapping[0] + and "phi12" in name_mapping[0] + and "a1" in name_mapping[0] + and "a2" in name_mapping[0] + and "iota" in name_mapping[1] + and "s1_x" in name_mapping[1] + and "s1_y" in name_mapping[1] + and "s1_z" in name_mapping[1] + and "s2_x" in name_mapping[1] + and "s2_y" in name_mapping[1] + and "s2_z" in name_mapping[1] + ) + + def named_transform(x): + iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( + x["theta_jn"], x["phi_jl"], x["theta1"], x["theta2"], x["phi12"], x["a1"], x["a2"], x['M_c'], x['q'], self.freq_ref, x['phase_c'] + ) + return { + "iota": iota, + "s1_x": s1x, + "s1_y": s1y, + "s1_z": s1z, + "s2_x": s2x, + "s2_y": s2y, + "s2_z": s2z, + } + + self.transform_func = named_transform diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index b2fb8c25..396e5474 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -391,6 +391,164 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F return ra, dec +def spin_to_cartesian_spin( + thetaJN: Float, + phiJL: Float, + theta1: Float, + theta2: Float, + phi12: Float, + chi1: Float, + chi2: Float, + M_c: Float, + q: Float, + fRef: Float, + phiRef: Float, +) -> tuple[Float, Float, Float, Float, Float, Float, Float]: + """ + Transforming the spin parameters + + The code is based on the approach used in LALsimulation: + https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html + + Parameters: + ------- + thetaJN: Float + Zenith angle between the total angular momentum and the line of sight + phiJL: Float + Difference between total and orbital angular momentum azimuthal angles + theta1: Float + Zenith angle between the spin and orbital angular momenta for the primary object + theta2: Float + Zenith angle between the spin and orbital angular momenta for the secondary object + phi12: Float + Difference between the azimuthal angles of the individual spin vector 
projections
+        onto the orbital plane
+    chi1: Float
+        Primary object dimensionless spin magnitude
+    chi2: Float
+        Secondary object dimensionless spin magnitude
+    M_c: Float
+        The chirp mass
+    q: Float
+        The mass ratio
+    fRef: Float
+        The reference frequency
+    phiRef: Float
+        Binary phase at a reference frequency
+
+    Returns:
+    -------
+    iota: Float
+        Zenith angle between the orbital angular momentum and the line of sight
+    S1x: Float
+        The x-component of the primary spin
+    S1y: Float
+        The y-component of the primary spin
+    S1z: Float
+        The z-component of the primary spin
+    S2x: Float
+        The x-component of the secondary spin
+    S2y: Float
+        The y-component of the secondary spin
+    S2z: Float
+        The z-component of the secondary spin
+    """
+
+    def rotate_y(angle, vec):
+        """
+        Rotate the vector (x, y, z) about y-axis
+        """
+        cos_angle = jnp.cos(angle)
+        sin_angle = jnp.sin(angle)
+        rotation_matrix = jnp.array(
+            [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]]
+        )
+        rotated_vec = jnp.dot(rotation_matrix, vec)
+        return rotated_vec
+
+    def rotate_z(angle, vec):
+        """
+        Rotate the vector (x, y, z) about z-axis
+        """
+        cos_angle = jnp.cos(angle)
+        sin_angle = jnp.sin(angle)
+        rotation_matrix = jnp.array(
+            [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]]
+        )
+        rotated_vec = jnp.dot(rotation_matrix, vec)
+        return rotated_vec
+
+    LNh = jnp.array([0.0, 0.0, 1.0])
+
+    s1hat = jnp.array(
+        [
+            jnp.sin(theta1) * jnp.cos(phiRef),
+            jnp.sin(theta1) * jnp.sin(phiRef),
+            jnp.cos(theta1),
+        ]
+    )
+    s2hat = jnp.array(
+        [
+            jnp.sin(theta2) * jnp.cos(phi12 + phiRef),
+            jnp.sin(theta2) * jnp.sin(phi12 + phiRef),
+            jnp.cos(theta2),
+        ]
+    )
+
+    m1, m2 = Mc_q_to_m1_m2(M_c, q)
+    eta = q / (1 + q) ** 2
+    v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef)
+
+    Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0))
+    s1 = m1 * m1 * chi1 * s1hat
+    s2 = m2 * m2 * chi2 * s2hat
+    J = s1 + s2 + jnp.array([0.0, 0.0, Lmag])
+
+    Jhat = J / jnp.linalg.norm(J)
+    theta0 
= jnp.arccos(Jhat[2]) + phi0 = jnp.arctan2(Jhat[1], Jhat[0]) + + # Rotation 1: + s1hat = rotate_z(-phi0, s1hat) + s2hat = rotate_z(-phi0, s2hat) + + # Rotation 2: + LNh = rotate_y(-theta0, LNh) + s1hat = rotate_y(-theta0, s1hat) + s2hat = rotate_y(-theta0, s2hat) + + # Rotation 3: + LNh = rotate_z(phiJL - jnp.pi, LNh) + s1hat = rotate_z(phiJL - jnp.pi, s1hat) + s2hat = rotate_z(phiJL - jnp.pi, s2hat) + + # Compute iota + N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) + iota = jnp.arccos(jnp.dot(N, LNh)) + + thetaLJ = jnp.arccos(LNh[2]) + phiL = jnp.arctan2(LNh[1], LNh[0]) + + # Rotation 4: + s1hat = rotate_z(-phiL, s1hat) + s2hat = rotate_z(-phiL, s2hat) + N = rotate_z(-phiL, N) + + # Rotation 5: + s1hat = rotate_y(-thetaLJ, s1hat) + s2hat = rotate_y(-thetaLJ, s2hat) + N = rotate_y(-thetaLJ, N) + + # Rotation 6: + phiN = jnp.arctan2(N[1], N[0]) + s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat) + s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat) + + S1 = s1hat * chi1 + S2 = s2hat * chi2 + return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] + + def zenith_azimuth_to_ra_dec( zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: diff --git a/test/integration/test_GW150914_pv2.py b/test/integration/test_GW150914_pv2.py new file mode 100644 index 00000000..da2e510d --- /dev/null +++ b/test/integration/test_GW150914_pv2.py @@ -0,0 +1,141 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, 
MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior(0.125, 1., parameter_names=["q"]) +theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) +phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) +theta_1_prior = SinePrior(parameter_names=["theta_1"]) +theta_2_prior = SinePrior(parameter_names=["theta_2"]) +phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) +a1_prior = UniformPrior(0.0, 1.0, parameter_names=["a1"]) +a2_prior = UniformPrior(0.0, 1.0, parameter_names=["a2"]) +dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + Mc_prior, + q_prior, + theta_jn_prior, + phi_jl_prior, + theta_1_prior, + theta_2_prior, + phi_12_prior, + a1_prior, + a2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], 
["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["a1"], ["a1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["a2"], ["a2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) +] + +likelihood_transforms = [ + SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a1", "a2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], 
freq_ref=20.0), + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(15) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) \ No newline at end of file From b972aead4452fc5b8e1d90c716cffe937003f559 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:17:05 -0400 Subject: [PATCH 130/172] Updated transform.py --- src/jimgw/single_event/transforms.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 2f7805e4..71e52259 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -190,11 +190,11 @@ def __init__( assert ( "theta_jn" in name_mapping[0] and "phi_jl" in name_mapping[0] - and "theta1" in name_mapping[0] - and "theta2" in name_mapping[0] - and "phi12" in name_mapping[0] - and "a1" in name_mapping[0] - and "a2" in name_mapping[0] + and "theta_1" in name_mapping[0] + and "theta_2" in name_mapping[0] + and 
"phi_12" in name_mapping[0] + and "a_1" in name_mapping[0] + and "a_2" in name_mapping[0] and "iota" in name_mapping[1] and "s1_x" in name_mapping[1] and "s1_y" in name_mapping[1] From d4a4386a0e52393e88226e41f368f594005d75df Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:17:27 -0400 Subject: [PATCH 131/172] Updated test_GW150914_pv2.py --- test/integration/test_GW150914_pv2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/test_GW150914_pv2.py b/test/integration/test_GW150914_pv2.py index da2e510d..8d8f7884 100644 --- a/test/integration/test_GW150914_pv2.py +++ b/test/integration/test_GW150914_pv2.py @@ -90,7 +90,7 @@ ] likelihood_transforms = [ - SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a1", "a2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), + SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), ] From 0d6a5e5c44a66789b7f4ba288b26e57bf396f08c Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:19:00 -0400 Subject: [PATCH 132/172] Updated transform.py --- src/jimgw/single_event/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 71e52259..aac2fddb 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -206,7 +206,7 @@ def __init__( def named_transform(x): iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( - x["theta_jn"], x["phi_jl"], x["theta1"], x["theta2"], x["phi12"], x["a1"], x["a2"], x['M_c'], x['q'], self.freq_ref, x['phase_c'] + x["theta_jn"], x["phi_jl"], 
x["theta_1"], x["theta_2"], x["phi_12"], x["a_1"], x["a_2"], x['M_c'], x['q'], self.freq_ref, x['phase_c'] ) return { "iota": iota, From 91f2f8918130d6e586a3e114c9779c77288b1771 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:21:46 -0400 Subject: [PATCH 133/172] Updated test_GW150914_pv2.py --- test/integration/test_GW150914_pv2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/integration/test_GW150914_pv2.py b/test/integration/test_GW150914_pv2.py index 8d8f7884..753ec5b3 100644 --- a/test/integration/test_GW150914_pv2.py +++ b/test/integration/test_GW150914_pv2.py @@ -42,8 +42,8 @@ theta_1_prior = SinePrior(parameter_names=["theta_1"]) theta_2_prior = SinePrior(parameter_names=["theta_2"]) phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) -a1_prior = UniformPrior(0.0, 1.0, parameter_names=["a1"]) -a2_prior = UniformPrior(0.0, 1.0, parameter_names=["a2"]) +a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) +a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) @@ -60,8 +60,8 @@ theta_1_prior, theta_2_prior, phi_12_prior, - a1_prior, - a2_prior, + a_1_prior, + a_2_prior, dL_prior, t_c_prior, phase_c_prior, @@ -79,8 +79,8 @@ BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["a1"], ["a1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = 
[["a2"], ["a2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), From 6fc8887f467027f3ce45514f35e0b61e1ff1d279 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:21:59 -0400 Subject: [PATCH 134/172] Updated utils.py --- src/jimgw/single_event/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 396e5474..a15bd7bf 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -2,6 +2,8 @@ from jax.scipy.integrate import trapezoid from jaxtyping import Array, Float +from jimgw.constants import Msun + def inner_product( h1: Float[Array, " n_sample"], From 49d604d3db54a09376145449fc69480d3226c7f2 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 13:31:09 -0400 Subject: [PATCH 135/172] Fix mass transform test --- src/jimgw/jim.py | 16 ++++++++++++---- test/integration/test_GW150914.py | 19 +++++++++++-------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 7d86fecf..a13d0d3d 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -104,10 +104,18 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: - initial_guess_named = 
self.prior.sample(key, self.sampler.n_chains) - for transform in self.sample_transforms: - initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T + initial_guess = [] + for i in range(self.sampler.n_chains): + flag = True + while flag: + key = jax.random.split(key)[1] + guess = self.prior.sample(key, 1) + for transform in self.sample_transforms: + guess = transform.forward(guess) + guess = jnp.array([i for i in guess.values()]).T[0] + flag = not jnp.all(jnp.isfinite(guess)) + initial_guess.append(guess) + initial_guess = jnp.array(initial_guess) self.sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 6fddf9ea..90cba0f1 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -9,7 +9,7 @@ from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam @@ -35,8 +35,10 @@ for ifo in ifos: ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], 
Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) @@ -49,8 +51,8 @@ prior = CombinePrior( [ - Mc_prior, - q_prior, + m_1_prior, + m_2_prior, s1z_prior, s2z_prior, dL_prior, @@ -64,8 +66,9 @@ ) sample_transforms = [ - BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), @@ -79,7 +82,7 @@ ] likelihood_transforms = [ - MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), ] likelihood = TransientLikelihoodFD( From 3b7f9e882532c2056632f1f7c23487c4c72779df Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:40:40 -0400 Subject: [PATCH 136/172] Reformated --- src/jimgw/jim.py | 2 +- src/jimgw/single_event/transforms.py | 24 +++++++++++++++++------- 2 
files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index a31e614b..b6c91582 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -210,7 +210,7 @@ def get_samples(self, training: bool = False) -> dict: """ if training: - chains = self.sampler.get_sampler_state(training=True)["chains"] # (500, 10100, 15) + chains = self.sampler.get_sampler_state(training=True)["chains"] else: chains = self.sampler.get_sampler_state(training=False)["chains"] diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index aac2fddb..c3e77846 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -168,7 +168,7 @@ def named_inverse_transform(x): return {"ra": ra, "dec": dec} self.inverse_transform_func = named_inverse_transform - + @jaxtyped(typechecker=typechecker) class SpinToCartesianSpinTransform(NtoNTransform): @@ -177,16 +177,16 @@ class SpinToCartesianSpinTransform(NtoNTransform): """ freq_ref: Float - + def __init__( self, name_mapping: tuple[list[str], list[str]], freq_ref: Float, ): super().__init__(name_mapping) - + self.freq_ref = freq_ref - + assert ( "theta_jn" in name_mapping[0] and "phi_jl" in name_mapping[0] @@ -203,10 +203,20 @@ def __init__( and "s2_y" in name_mapping[1] and "s2_z" in name_mapping[1] ) - + def named_transform(x): iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( - x["theta_jn"], x["phi_jl"], x["theta_1"], x["theta_2"], x["phi_12"], x["a_1"], x["a_2"], x['M_c'], x['q'], self.freq_ref, x['phase_c'] + x["theta_jn"], + x["phi_jl"], + x["theta_1"], + x["theta_2"], + x["phi_12"], + x["a_1"], + x["a_2"], + x["M_c"], + x["q"], + self.freq_ref, + x["phase_c"], ) return { "iota": iota, @@ -217,5 +227,5 @@ def named_transform(x): "s2_y": s2y, "s2_z": s2z, } - + self.transform_func = named_transform From ac0b1f57ab4ebb07e7d984f79735e9a03e49be86 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 13:43:49 -0400 Subject: 
[PATCH 137/172] Use PowerLaw for distance --- test/integration/{test_GW150914.py => test_GW150914_D.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename test/integration/{test_GW150914.py => test_GW150914_D.py} (98%) diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914_D.py similarity index 98% rename from test/integration/test_GW150914.py rename to test/integration/test_GW150914_D.py index 90cba0f1..9a39434e 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914_D.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD @@ -41,7 +41,7 @@ m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) -dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) iota_prior = SinePrior(parameter_names=["iota"]) From e0a61fa163899f8c74ad8db1d70225be527e7704 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:44:44 -0400 Subject: [PATCH 138/172] Reformated --- test/integration/test_GW150914_pv2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/integration/test_GW150914_pv2.py b/test/integration/test_GW150914_pv2.py index 753ec5b3..5a15bd9b 100644 --- a/test/integration/test_GW150914_pv2.py +++ 
b/test/integration/test_GW150914_pv2.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD @@ -44,7 +44,7 @@ phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) -dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) +dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) From 63747724ab11c3091c986e0781592ffbcd88cb9f Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:26:40 -0400 Subject: [PATCH 139/172] Rename test_GW150914_pv2.py to test_GW150914_PV2.py --- test/integration/{test_GW150914_pv2.py => test_GW150914_PV2.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename test/integration/{test_GW150914_pv2.py => test_GW150914_PV2.py} (99%) diff --git a/test/integration/test_GW150914_pv2.py b/test/integration/test_GW150914_PV2.py similarity index 99% rename from test/integration/test_GW150914_pv2.py rename to test/integration/test_GW150914_PV2.py index 5a15bd9b..6be02936 100644 --- a/test/integration/test_GW150914_pv2.py +++ b/test/integration/test_GW150914_PV2.py @@ -138,4 +138,4 @@ strategies=[Adam_optimizer, "default"], ) -jim.sample(jax.random.PRNGKey(42)) \ No newline at end of file +jim.sample(jax.random.PRNGKey(42)) From fa134f9b0f575f7d5921e7083afa5b4c81b99400 Mon Sep 17 00:00:00 2001 From: 
Thomas Ng Date: Fri, 2 Aug 2024 15:04:51 -0400 Subject: [PATCH 140/172] Add heterodyne test --- src/jimgw/jim.py | 2 +- src/jimgw/single_event/likelihood.py | 66 +++++++-- test/integration/test_GW150914_D.py | 4 - .../integration/test_GW150914_D_heterodyne.py | 131 ++++++++++++++++++ 4 files changed, 184 insertions(+), 19 deletions(-) create mode 100644 test/integration/test_GW150914_D_heterodyne.py diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index a13d0d3d..2063b0bf 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -105,7 +105,7 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess = [] - for i in range(self.sampler.n_chains): + for _ in range(self.sampler.n_chains): flag = True while flag: key = jax.random.split(key)[1] diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index ce2e8f0e..0ccb1ce8 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -14,6 +14,7 @@ from jimgw.single_event.detector import Detector from jimgw.utils import log_i0 from jimgw.single_event.waveform import Waveform +from jimgw.transforms import BijectiveTransform, NtoMTransform class SingleEventLiklihood(LikelihoodBase): @@ -184,8 +185,6 @@ def __init__( self, detectors: list[Detector], waveform: Waveform, - prior: Prior, - bounds: Float[Array, " n_dim 2"], n_bins: int = 100, trigger_time: float = 0, duration: float = 4, @@ -194,6 +193,9 @@ def __init__( n_steps: int = 2000, ref_params: dict = {}, reference_waveform: Optional[Waveform] = None, + prior: Optional[Prior] = None, + sample_transforms: Optional[list[BijectiveTransform]] = [], + likelihood_transforms: Optional[list[NtoMTransform]] = [], **kwargs, ) -> None: super().__init__( @@ -254,17 +256,24 @@ def __init__( ) self.freq_grid_low = freq_grid[:-1] - if not ref_params: + if ref_params: + self.ref_params = 
ref_params + print(f"Reference parameters provided, which are {self.ref_params}") + elif prior: print("No reference parameters are provided, finding it...") ref_params = self.maximize_likelihood( - bounds=bounds, prior=prior, popsize=popsize, n_steps=n_steps + prior=prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + popsize=popsize, + n_steps=n_steps, ) self.ref_params = {key: float(value) for key, value in ref_params.items()} print(f"The reference parameters are {self.ref_params}") else: - self.ref_params = ref_params - print(f"Reference parameters provided, which are {self.ref_params}") - + raise ValueError( + "Either reference parameters or parameter names must be provided" + ) # safe guard for the reference parameters # since ripple cannot handle eta=0.25 if jnp.isclose(self.ref_params["eta"], 0.25): @@ -542,25 +551,54 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): def maximize_likelihood( self, prior: Prior, - bounds: Float[Array, " n_dim 2"], + likelihood_transforms: list[BijectiveTransform], + sample_transforms: list[NtoMTransform], popsize: int = 100, n_steps: int = 2000, ): + parameter_names = prior.parameter_names + for transform in sample_transforms: + parameter_names = transform.propagate_name(parameter_names) + def y(x): - return -self.evaluate_original(prior.transform(prior.add_name(x)), {}) + named_params = dict(zip(parameter_names, x)) + for transform in reversed(sample_transforms): + named_params = transform.backward(named_params) + for transform in likelihood_transforms: + named_params = transform.forward(named_params) + return -self.evaluate_original(named_params, {}) print("Starting the optimizer") + optimizer = optimization_Adam( n_steps=n_steps, learning_rate=0.001, noise_level=1 ) - initial_position = jnp.array( - list(prior.sample(jax.random.PRNGKey(0), popsize).values()) - ).T + + key = jax.random.PRNGKey(0) + initial_position = [] + for _ in range(popsize): + flag = 
True + while flag: + key = jax.random.split(key)[1] + guess = prior.sample(key, 1) + for transform in sample_transforms: + guess = transform.forward(guess) + guess = jnp.array([i for i in guess.values()]).T[0] + flag = not jnp.all(jnp.isfinite(guess)) + initial_position.append(guess) + initial_position = jnp.array(initial_position) rng_key, optimized_positions, summary = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position ) - best_fit = optimized_positions[jnp.nanargmin(summary["final_log_prob"])] - return prior.transform(prior.add_name(best_fit)) + + best_fit = optimized_positions[jnp.argmin(summary["final_log_prob"])] + + named_params = dict(zip(parameter_names, best_fit)) + for transform in reversed(sample_transforms): + named_params = transform.backward(named_params) + for transform in likelihood_transforms: + named_params = transform.forward(named_params) + return named_params likelihood_presets = { diff --git a/test/integration/test_GW150914_D.py b/test/integration/test_GW150914_D.py index 9a39434e..ba3ce2d6 100644 --- a/test/integration/test_GW150914_D.py +++ b/test/integration/test_GW150914_D.py @@ -1,5 +1,3 @@ -import time - import jax import jax.numpy as jnp @@ -19,8 +17,6 @@ ########## First we grab data ############# ########################################### -total_time_start = time.time() - # first, fetch a 4s segment centered on GW150914 gps = 1126259462.4 duration = 4 diff --git a/test/integration/test_GW150914_D_heterodyne.py b/test/integration/test_GW150914_D_heterodyne.py new file mode 100644 index 00000000..66093b88 --- /dev/null +++ b/test/integration/test_GW150914_D_heterodyne.py @@ -0,0 +1,131 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import 
RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + 
ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] + +likelihood = HeterodynedTransientLikelihoodFD( + ifos, + prior=prior, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = 
mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) From 2c9c6a62fac657e9d54137ff151e9cc66e24385c Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 15:21:02 -0400 Subject: [PATCH 141/172] Fix typecheck --- src/jimgw/jim.py | 8 ++++---- src/jimgw/single_event/likelihood.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 2063b0bf..043c4672 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -102,8 +102,8 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = transform.forward(named_params) return self.likelihood.evaluate(named_params, data) + prior - def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): - if initial_guess.size == 0: + def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): + if initial_position.size == 0: initial_guess = [] for _ in range(self.sampler.n_chains): flag = True @@ -115,8 +115,8 @@ def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): guess = jnp.array([i for i in guess.values()]).T[0] flag = not jnp.all(jnp.isfinite(guess)) initial_guess.append(guess) - initial_guess = jnp.array(initial_guess) - 
self.sampler.sample(initial_guess, None) # type: ignore + initial_position = jnp.array(initial_guess) + self.sampler.sample(initial_position, None) # type: ignore def maximize_likelihood( self, diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 0ccb1ce8..00e6ce6b 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -194,8 +194,8 @@ def __init__( ref_params: dict = {}, reference_waveform: Optional[Waveform] = None, prior: Optional[Prior] = None, - sample_transforms: Optional[list[BijectiveTransform]] = [], - likelihood_transforms: Optional[list[NtoMTransform]] = [], + sample_transforms: list[BijectiveTransform] = [], + likelihood_transforms: list[NtoMTransform] = [], **kwargs, ) -> None: super().__init__( @@ -551,8 +551,8 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): def maximize_likelihood( self, prior: Prior, - likelihood_transforms: list[BijectiveTransform], - sample_transforms: list[NtoMTransform], + likelihood_transforms: list[NtoMTransform], + sample_transforms: list[BijectiveTransform], popsize: int = 100, n_steps: int = 2000, ): From 15e40748844a9f5ae33bf60f3aa0880ff5632dfa Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 15:25:09 -0400 Subject: [PATCH 142/172] Shorten test runtime --- test/integration/test_GW150914_D_heterodyne.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/integration/test_GW150914_D_heterodyne.py b/test/integration/test_GW150914_D_heterodyne.py index 66093b88..b5945cee 100644 --- a/test/integration/test_GW150914_D_heterodyne.py +++ b/test/integration/test_GW150914_D_heterodyne.py @@ -90,6 +90,8 @@ post_trigger_duration=2, sample_transforms=sample_transforms, likelihood_transforms=likelihood_transforms, + n_steps=5, + popsize=10, ) From 174d43a1fdc4fecb4a8b7dd6a51aadc102aff12a Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Sat, 3 Aug 2024 07:27:18 -0400 Subject: [PATCH 143/172] Fix jim output 
functions --- src/jimgw/jim.py | 6 +++--- test/integration/test_GW150914_D.py | 2 ++ test/integration/test_GW150914_D_heterodyne.py | 2 ++ test/integration/test_GW150914_PV2.py | 5 +++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index cc1ff14b..fae0bc98 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -155,7 +155,7 @@ def print_summary(self, transform: bool = True): training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T training_chain = self.add_name(training_chain) if transform: - for sample_transform in self.sample_transforms: + for sample_transform in reversed(self.sample_transforms): training_chain = sample_transform.backward(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] @@ -165,7 +165,7 @@ def print_summary(self, transform: bool = True): production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T production_chain = self.add_name(production_chain) if transform: - for sample_transform in self.sample_transforms: + for sample_transform in reversed(self.sample_transforms): production_chain = sample_transform.backward(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] @@ -224,7 +224,7 @@ def get_samples(self, training: bool = False) -> dict: chains = chains.transpose(2, 0, 1) chains = self.add_name(chains) - for sample_transform in self.sample_transforms: + for sample_transform in reversed(self.sample_transforms): chains = sample_transform.backward(chains) return chains diff --git a/test/integration/test_GW150914_D.py b/test/integration/test_GW150914_D.py index ba3ce2d6..e1eee9ac 100644 --- a/test/integration/test_GW150914_D.py +++ b/test/integration/test_GW150914_D.py @@ -126,3 +126,5 @@ ) jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() diff --git 
a/test/integration/test_GW150914_D_heterodyne.py b/test/integration/test_GW150914_D_heterodyne.py index b5945cee..bf97efdb 100644 --- a/test/integration/test_GW150914_D_heterodyne.py +++ b/test/integration/test_GW150914_D_heterodyne.py @@ -131,3 +131,5 @@ ) jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() diff --git a/test/integration/test_GW150914_PV2.py b/test/integration/test_GW150914_PV2.py index 6be02936..c9d83a5e 100644 --- a/test/integration/test_GW150914_PV2.py +++ b/test/integration/test_GW150914_PV2.py @@ -9,8 +9,7 @@ from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform -from jimgw.single_event.utils import Mc_q_to_m1_m2 +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -139,3 +138,5 @@ ) jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() From 14121d738d739079b395b7d18c59bc6417258a75 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Sat, 3 Aug 2024 07:29:26 -0400 Subject: [PATCH 144/172] Delete test_GW150914_PV2.py --- test/integration/test_GW150914_PV2.py | 142 -------------------------- 1 file changed, 142 deletions(-) delete mode 100644 test/integration/test_GW150914_PV2.py diff --git a/test/integration/test_GW150914_PV2.py b/test/integration/test_GW150914_PV2.py deleted file mode 100644 index c9d83a5e..00000000 --- a/test/integration/test_GW150914_PV2.py +++ /dev/null @@ -1,142 +0,0 @@ -import time - -import jax -import jax.numpy as jnp - -from jimgw.jim import Jim -from 
jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior -from jimgw.single_event.detector import H1, L1 -from jimgw.single_event.likelihood import TransientLikelihoodFD -from jimgw.single_event.waveform import RippleIMRPhenomD -from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform -from flowMC.strategy.optimization import optimization_Adam - -jax.config.update("jax_enable_x64", True) - -########################################### -########## First we grab data ############# -########################################### - -total_time_start = time.time() - -# first, fetch a 4s segment centered on GW150914 -gps = 1126259462.4 -duration = 4 -post_trigger_duration = 2 -start_pad = duration - post_trigger_duration -end_pad = post_trigger_duration -fmin = 20.0 -fmax = 1024.0 - -ifos = [H1, L1] - -for ifo in ifos: - ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) - -Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior(0.125, 1., parameter_names=["q"]) -theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) -phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) -theta_1_prior = SinePrior(parameter_names=["theta_1"]) -theta_2_prior = SinePrior(parameter_names=["theta_2"]) -phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) -a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) -a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) -dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) -t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) -phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) -ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) -dec_prior = CosinePrior(parameter_names=["dec"]) - -prior = 
CombinePrior( - [ - Mc_prior, - q_prior, - theta_jn_prior, - phi_jl_prior, - theta_1_prior, - theta_2_prior, - phi_12_prior, - a_1_prior, - a_2_prior, - dL_prior, - t_c_prior, - phase_c_prior, - psi_prior, - ra_prior, - dec_prior, - ] -) - -sample_transforms = [ - BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), - BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), - BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), - BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - 
BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) -] - -likelihood_transforms = [ - SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), - MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), -] - -likelihood = TransientLikelihoodFD( - ifos, - waveform=RippleIMRPhenomD(), - trigger_time=gps, - duration=4, - post_trigger_duration=2, -) - - -mass_matrix = jnp.eye(15) -mass_matrix = mass_matrix.at[1, 1].set(1e-3) -mass_matrix = mass_matrix.at[9, 9].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} - -Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) - -n_epochs = 2 -n_loop_training = 1 -learning_rate = 1e-4 - - -jim = Jim( - likelihood, - prior, - sample_transforms=sample_transforms, - likelihood_transforms=likelihood_transforms, - n_loop_training=n_loop_training, - n_loop_production=1, - n_local_steps=5, - n_global_steps=5, - n_chains=4, - n_epochs=n_epochs, - learning_rate=learning_rate, - n_max_examples=30, - n_flow_samples=100, - momentum=0.9, - batch_size=100, - use_global=True, - train_thinning=1, - output_thinning=1, - local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer, "default"], -) - -jim.sample(jax.random.PRNGKey(42)) -jim.get_samples() -jim.print_summary() From 0a9f58f38a40debd06849954cebae2b70b3dca7f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Sat, 3 Aug 2024 07:29:47 -0400 Subject: [PATCH 145/172] Create test_GW150914_Pv2.py --- test/integration/test_GW150914_Pv2.py | 142 ++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 test/integration/test_GW150914_Pv2.py diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py new file mode 100644 index 00000000..c9d83a5e --- /dev/null +++ 
b/test/integration/test_GW150914_Pv2.py @@ -0,0 +1,142 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior(0.125, 1., parameter_names=["q"]) +theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) +phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) +theta_1_prior = SinePrior(parameter_names=["theta_1"]) +theta_2_prior = SinePrior(parameter_names=["theta_2"]) +phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) +a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) +a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) +dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +psi_prior = UniformPrior(0.0, jnp.pi, 
parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + Mc_prior, + q_prior, + theta_jn_prior, + phi_jl_prior, + theta_1_prior, + theta_2_prior, + phi_12_prior, + a_1_prior, + a_2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, 
original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) +] + +likelihood_transforms = [ + SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(15) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() From f33a7824f03de5d602b22cfbf2ec93b46e278316 Mon Sep 17 00:00:00 2001 From: "Peter T. H. 
Pang" Date: Mon, 12 Aug 2024 15:49:26 +0200 Subject: [PATCH 146/172] Adding transform from geocentric arrival time to detector arrival time --- src/jimgw/single_event/transforms.py | 54 ++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index c3e77846..b1735f5a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -170,6 +170,60 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): + """ + Transform the geocentric arrival time to detector arrival time + + In the geocentric convention, the arrival time of the signal at the + center of Earth is gps_time + t_c + + In the detector convention, the arrival time of the signal at the + detecotr is gps_time + time_delay_from_geo_to_det + t_det + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + gmst: Float + t_c: Float + t_det: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifo: GroundBased2G, + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifo = ifo + + assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] + + def named_transform(x): + t_det = x["t_c"] + self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + return { + "t_det": t_det, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + t_c = x["t_det"] - self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + return { + "t_c": t_c, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class SpinToCartesianSpinTransform(NtoNTransform): """ From 3505394c6e1f974134ecf59e326507c376c54621 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 12 Aug 2024 16:01:11 +0200 Subject: [PATCH 147/172] Adding transform from distance to SNR weighted distance --- src/jimgw/single_event/transforms.py | 88 +++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index b1735f5a..6404fb7e 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -208,7 +208,9 @@ def __init__( assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] def named_transform(x): - t_det = x["t_c"] + self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + t_det = x["t_c"] + self.ifo.delay_from_geocenter( + x["ra"], x["dec"], self.gmst + ) return { "t_det": t_det, } @@ -216,7 +218,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - t_c = x["t_det"] - self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + t_c = x["t_det"] - self.ifo.delay_from_geocenter( + x["ra"], x["dec"], 
self.gmst + ) return { "t_c": t_c, } @@ -224,6 +228,86 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): + """ + Transform the luminosity distance to network SNR weighted distance + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifos: list[GroundBased2G], + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifos = ifos + + assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] + + def named_transform(x): + d_L, M_c, ra, dec, psi, iota = ( + x["d_L"], + x["M_c"], + x["ra"], + x["dec"], + x["psi"], + x["iota"], + ) + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_ks2 = 0.0 + for ifo in self.ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_ks2 += p_mode_term**2 + c_mode_term**2 + R_ks = jnp.sqrt(R_ks2) + d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_ks + return { + "d_hat": d_hat, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + d_hat, M_c, ra, dec, psi, iota = ( + x["d_hat"], + x["M_c"], + x["ra"], + x["dec"], + x["psi"], + x["iota"], + ) + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_ks2 = 0.0 + for ifo in self.ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_ks2 += p_mode_term**2 + c_mode_term**2 + R_ks = jnp.sqrt(R_ks2) + d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_ks + return { + 
"d_L": d_L, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class SpinToCartesianSpinTransform(NtoNTransform): """ From df75cebc636eec563bf361e10c476465c161fe38 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 10:12:45 +0200 Subject: [PATCH 148/172] updating the typing for object attributes --- src/jimgw/single_event/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 6404fb7e..e1a49e58 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -189,8 +189,7 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): """ gmst: Float - t_c: Float - t_det: Float + ifo: GroundBased2G def __init__( self, @@ -241,6 +240,7 @@ class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): """ gmst: Float + ifos: list[GroundBased2G] def __init__( self, From 9f2f52b323bd03923043a0be4203030229225f71 Mon Sep 17 00:00:00 2001 From: "Peter T. H. 
Pang" Date: Tue, 13 Aug 2024 11:35:34 +0200 Subject: [PATCH 149/172] Adding geocentric phase to detector phase --- src/jimgw/single_event/transforms.py | 106 +++++++++++++++++++++------ 1 file changed, 83 insertions(+), 23 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index e1a49e58..bf060b8a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -227,6 +227,72 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): + """ + Transform the geocentric arrival phase to detector arrival phase + + In the geocentric convention, the arrival phase of the signal at the + center of Earth is phi_c / 2 (in ripple, phi_c is the orbital phase) + + In the detector convention, the arrival phase of the signal at the + detecotr is phi_det = phi_c / 2 + arg R_det + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + gmst: Float + ifo: GroundBased2G + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifo: GroundBased2G, + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifo = ifo + + assert "phi_c" in name_mapping[0] and "phi_det" in name_mapping[1] + + def _calc_R_det(x): + ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + + return p_mode_term - 1j * c_mode_term + + def named_transform(x): + R_det = _calc_R_det(x) + phi_det = jnp.angle(R_det) + x["phi_c"] / 2.0 + return { + "phi_det": phi_det, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + R_det = _calc_R_det(x) + phi_c = (-jnp.angle(R_det) + x["phi_det"]) * 2.0 + return { + "phi_c": phi_c, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): """ @@ -257,10 +323,8 @@ def __init__( assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] - def named_transform(x): - d_L, M_c, ra, dec, psi, iota = ( - x["d_L"], - x["M_c"], + def _calc_R_dets(x): + ra, dec, psi, iota = ( x["ra"], x["dec"], x["psi"], @@ -268,14 +332,22 @@ def named_transform(x): ) p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - R_ks2 = 0.0 + R_dets2 = 0.0 for ifo in self.ifos: antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] - R_ks2 += p_mode_term**2 + c_mode_term**2 - R_ks = jnp.sqrt(R_ks2) - d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_ks + R_dets2 += p_mode_term**2 + 
c_mode_term**2 + + return jnp.sqrt(R_dets2) + + def named_transform(x): + d_L, M_c = ( + x["d_L"], + x["M_c"], + ) + R_dets = _calc_R_dets(x) + d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_dets return { "d_hat": d_hat, } @@ -283,24 +355,12 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - d_hat, M_c, ra, dec, psi, iota = ( + d_hat, M_c = ( x["d_hat"], x["M_c"], - x["ra"], - x["dec"], - x["psi"], - x["iota"], ) - p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 - c_iota_term = jnp.cos(iota) - R_ks2 = 0.0 - for ifo in self.ifos: - antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) - p_mode_term = p_iota_term * antenna_pattern["p"] - c_mode_term = c_iota_term * antenna_pattern["c"] - R_ks2 += p_mode_term**2 + c_mode_term**2 - R_ks = jnp.sqrt(R_ks2) - d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_ks + R_dets = _calc_R_dets(x) + d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_dets return { "d_L": d_L, } From b62970f2a8f05bae53408870763aa115ea957fdc Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 11:51:22 +0200 Subject: [PATCH 150/172] Adding ZeroLikelihood for testing purpose --- src/jimgw/single_event/likelihood.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 00e6ce6b..9e775b33 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -26,6 +26,15 @@ def __init__(self, detectors: list[Detector], waveform: Waveform) -> None: self.waveform = waveform +class ZeroLikelihood(LikelihoodBase): + + def __init__(self): + pass + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + return 0.0 + + class TransientLikelihoodFD(SingleEventLiklihood): def __init__( self, From 4ea332224fa35363961220b1dffccc167599a6c3 Mon Sep 17 00:00:00 2001 From: "Peter T. H. 
Pang" Date: Tue, 13 Aug 2024 12:09:47 +0200 Subject: [PATCH 151/172] Adding the missing mode 2pi for phasing transform --- src/jimgw/single_event/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index bf060b8a..6f4f361a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -278,7 +278,7 @@ def named_transform(x): R_det = _calc_R_det(x) phi_det = jnp.angle(R_det) + x["phi_c"] / 2.0 return { - "phi_det": phi_det, + "phi_det": phi_det % (2. * jnp.pi), } self.transform_func = named_transform @@ -287,7 +287,7 @@ def named_inverse_transform(x): R_det = _calc_R_det(x) phi_c = (-jnp.angle(R_det) + x["phi_det"]) * 2.0 return { - "phi_c": phi_c, + "phi_c": phi_c % (2. * jnp.pi), } self.inverse_transform_func = named_inverse_transform From 7a4bae0e8d06735458b6f1e004258d1afb15aac2 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 13:38:23 +0200 Subject: [PATCH 152/172] Test wip --- test/integration/test_extrinsic.py | 108 ++++++++++++++++++ .../integration/test_extrinsic_no_distance.py | 90 +++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 test/integration/test_extrinsic.py create mode 100644 test/integration/test_extrinsic_no_distance.py diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py new file mode 100644 index 00000000..c4719dd8 --- /dev/null +++ b/test/integration/test_extrinsic.py @@ -0,0 +1,108 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import ZeroLikelihood +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, 
DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 + +ifos = [H1, L1, V1] + +M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + M_c_prior, + q_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # all the user reparametrization transform + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], gps_time=gps, ifos=ifos), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, 
original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=2.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), +] + +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + +likelihood = ZeroLikelihood() + +mass_matrix = jnp.eye(9) +#mass_matrix = mass_matrix.at[1, 1].set(1e-3) +#mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() diff --git 
a/test/integration/test_extrinsic_no_distance.py b/test/integration/test_extrinsic_no_distance.py new file mode 100644 index 00000000..d1b1e559 --- /dev/null +++ b/test/integration/test_extrinsic_no_distance.py @@ -0,0 +1,90 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import ZeroLikelihood +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 + +ifos = [H1, L1, V1] + +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + t_c_prior, + phase_c_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # all the user reparametrization transform + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + 
BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), +] + +likelihood_transforms = [] + +likelihood = ZeroLikelihood() + +mass_matrix = jnp.eye(9) +#mass_matrix = mass_matrix.at[1, 1].set(1e-3) +#mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() From d5f86e52156c788a06d84c1ac5d0958ae4d0b8fb Mon Sep 17 00:00:00 2001 From: "Peter T. H. 
Pang" Date: Tue, 13 Aug 2024 13:42:45 +0200 Subject: [PATCH 153/172] Phase renaming --- src/jimgw/single_event/transforms.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 6f4f361a..1242b40b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -233,10 +233,10 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): Transform the geocentric arrival phase to detector arrival phase In the geocentric convention, the arrival phase of the signal at the - center of Earth is phi_c / 2 (in ripple, phi_c is the orbital phase) + center of Earth is phase_c / 2 (in ripple, phase_c is the orbital phase) In the detector convention, the arrival phase of the signal at the - detecotr is phi_det = phi_c / 2 + arg R_det + detecotr is phase_det = phase_c / 2 + arg R_det Parameters ---------- @@ -261,7 +261,7 @@ def __init__( ) self.ifo = ifo - assert "phi_c" in name_mapping[0] and "phi_det" in name_mapping[1] + assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1] def _calc_R_det(x): ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] @@ -276,18 +276,18 @@ def _calc_R_det(x): def named_transform(x): R_det = _calc_R_det(x) - phi_det = jnp.angle(R_det) + x["phi_c"] / 2.0 + phase_det = jnp.angle(R_det) + x["phase_c"] / 2.0 return { - "phi_det": phi_det % (2. * jnp.pi), + "phase_det": phase_det % (2. * jnp.pi), } self.transform_func = named_transform def named_inverse_transform(x): R_det = _calc_R_det(x) - phi_c = (-jnp.angle(R_det) + x["phi_det"]) * 2.0 + phase_c = (-jnp.angle(R_det) + x["phase_det"]) * 2.0 return { - "phi_c": phi_c % (2. * jnp.pi), + "phase_c": phase_c % (2. 
* jnp.pi), } self.inverse_transform_func = named_inverse_transform From 0a2e68c22611fea72a2643bbea8ab358c67c3650 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Tue, 13 Aug 2024 06:09:13 -0700 Subject: [PATCH 154/172] wip --- src/jimgw/single_event/transforms.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 1242b40b..538d6b8e 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -208,7 +208,7 @@ def __init__( def named_transform(x): t_det = x["t_c"] + self.ifo.delay_from_geocenter( - x["ra"], x["dec"], self.gmst + x["ra"][0], x["dec"][0], self.gmst ) return { "t_det": t_det, @@ -217,8 +217,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): + import pdb; pdb.set_trace() t_c = x["t_det"] - self.ifo.delay_from_geocenter( - x["ra"], x["dec"], self.gmst + x["ra"][0], x["dec"][0], self.gmst ) return { "t_c": t_c, @@ -268,7 +269,7 @@ def _calc_R_det(x): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) + antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] @@ -278,7 +279,7 @@ def named_transform(x): R_det = _calc_R_det(x) phase_det = jnp.angle(R_det) + x["phase_c"] / 2.0 return { - "phase_det": phase_det % (2. * jnp.pi), + "phase_det": phase_det % (2.0 * jnp.pi), } self.transform_func = named_transform @@ -287,7 +288,7 @@ def named_inverse_transform(x): R_det = _calc_R_det(x) phase_c = (-jnp.angle(R_det) + x["phase_det"]) * 2.0 return { - "phase_c": phase_c % (2. 
* jnp.pi), + "phase_c": phase_c % (2.0 * jnp.pi), } self.inverse_transform_func = named_inverse_transform From b96512c64139661b38fd7df1f2506d480c527ebb Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 14 Aug 2024 13:30:43 -0400 Subject: [PATCH 155/172] Push conditional bijective transform --- src/jimgw/transforms.py | 54 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 715d49de..4d2ebb45 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -61,8 +61,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: class NtoNTransform(NtoMTransform): - transform_func: Callable[[dict[str, Float]], dict[str, Float]] - @property def n_dim(self) -> int: return len(self.name_mapping[0]) @@ -162,6 +160,58 @@ def backward(self, y: dict[str, Float]) -> dict[str, Float]: list(output_params.keys()), ) return y_copy + +class ConditionalBijectiveTransform(BijectiveTransform): + + conditional_names: list[str] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], + ): + super().__init__(name_mapping) + self.conditional_names = conditional_names + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + x_copy = x.copy() + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + transform_params.update( + dict((key, x_copy[key]) for key in self.conditional_names) + ) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jax.tree.map( + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy, jacobian + + def inverse(self, y: dict[str, Float]) -> 
tuple[dict[str, Float], Float]: + y_copy = y.copy() + transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) + transform_params.update( + dict((key, y_copy[key]) for key in self.conditional_names) + ) + output_params = self.inverse_transform_func(transform_params) + jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jax.tree.map( + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy, jacobian @jaxtyped(typechecker=typechecker) From 526e33cbb8eae4c32782113cd2a7495d657e5ea9 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Wed, 14 Aug 2024 12:44:04 -0700 Subject: [PATCH 156/172] Switch to using conditional transform --- src/jimgw/single_event/transforms.py | 46 ++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 538d6b8e..2fb757c2 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -4,7 +4,7 @@ from astropy.time import Time from jimgw.single_event.detector import GroundBased2G -from jimgw.transforms import BijectiveTransform, NtoNTransform +from jimgw.transforms import ConditionalBijectiveTransform, BijectiveTransform, NtoNTransform from jimgw.single_event.utils import ( m1_m2_to_Mc_q, Mc_q_to_m1_m2, @@ -171,7 +171,7 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): +class GeocentricArrivalTimeToDetectorArrivalTimeTransform(ConditionalBijectiveTransform): """ Transform the geocentric arrival time to detector arrival time @@ -194,10 +194,11 @@ class 
GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, ): - super().__init__(name_mapping) + super().__init__(name_mapping, conditional_names) self.gmst = ( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad @@ -205,11 +206,22 @@ def __init__( self.ifo = ifo assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + ) + + def _calc_delay(x): + ra, dec = x["ra"], x["dec"] + if hasattr(ra, "shape") and len(ra.shape) > 0: + delay = self.ifo.delay_from_geocenter(ra[0], dec[0], self.gmst) + else: + delay = self.ifo.delay_from_geocenter(ra, dec, self.gmst) + return delay def named_transform(x): - t_det = x["t_c"] + self.ifo.delay_from_geocenter( - x["ra"][0], x["dec"][0], self.gmst - ) + delay = _calc_delay(x) + t_det = x["t_c"] + delay return { "t_det": t_det, } @@ -217,10 +229,8 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - import pdb; pdb.set_trace() - t_c = x["t_det"] - self.ifo.delay_from_geocenter( - x["ra"][0], x["dec"][0], self.gmst - ) + delay = _calc_delay(x) + t_c = x["t_det"] - delay return { "t_c": t_c, } @@ -229,7 +239,7 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(ConditionalBijectiveTransform): """ Transform the geocentric arrival phase to detector arrival phase @@ -252,10 +262,11 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, ): - super().__init__(name_mapping) + super().__init__(name_mapping, conditional_names) self.gmst = ( 
Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad @@ -263,13 +274,22 @@ def __init__( self.ifo = ifo assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + and "psi" in conditional_names + and "iota" in conditional_names + ) def _calc_R_det(x): ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) + if hasattr(ra, "shape") and len(ra.shape) > 0: + antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) + else: + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] From dbf3f3064e5135575a5c7548ee9e1c5e9157553b Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Wed, 14 Aug 2024 12:45:40 -0700 Subject: [PATCH 157/172] Switch to using conditional transform --- src/jimgw/single_event/transforms.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 2fb757c2..dfcd01b8 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -4,7 +4,11 @@ from astropy.time import Time from jimgw.single_event.detector import GroundBased2G -from jimgw.transforms import ConditionalBijectiveTransform, BijectiveTransform, NtoNTransform +from jimgw.transforms import ( + ConditionalBijectiveTransform, + BijectiveTransform, + NtoNTransform, +) from jimgw.single_event.utils import ( m1_m2_to_Mc_q, Mc_q_to_m1_m2, @@ -171,7 +175,9 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalTimeToDetectorArrivalTimeTransform(ConditionalBijectiveTransform): +class 
GeocentricArrivalTimeToDetectorArrivalTimeTransform( + ConditionalBijectiveTransform +): """ Transform the geocentric arrival time to detector arrival time @@ -206,10 +212,7 @@ def __init__( self.ifo = ifo assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] - assert ( - "ra" in conditional_names - and "dec" in conditional_names - ) + assert "ra" in conditional_names and "dec" in conditional_names def _calc_delay(x): ra, dec = x["ra"], x["dec"] @@ -239,7 +242,9 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(ConditionalBijectiveTransform): +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( + ConditionalBijectiveTransform +): """ Transform the geocentric arrival phase to detector arrival phase @@ -287,7 +292,9 @@ def _calc_R_det(x): c_iota_term = jnp.cos(iota) if hasattr(ra, "shape") and len(ra.shape) > 0: - antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) + antenna_pattern = self.ifo.antenna_pattern( + ra[0], dec[0], psi[0], self.gmst + ) else: antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] From a3753612926dd2b95b440bbc14043cdf7412df29 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Wed, 14 Aug 2024 12:46:04 -0700 Subject: [PATCH 158/172] Fixing jacobian handling --- src/jimgw/transforms.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4d2ebb45..a26ad9af 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -160,7 +160,8 @@ def backward(self, y: dict[str, Float]) -> dict[str, Float]: list(output_params.keys()), ) return y_copy - + + class ConditionalBijectiveTransform(BijectiveTransform): conditional_names: list[str] @@ -181,8 +182,14 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: ) output_params = 
self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian_copy = { + key1: {key2: jacobian[key1][key2] for key2 in self.name_mapping[0]} + for key1 in self.name_mapping[1] + } + jacobian = jnp.array(jax.tree.leaves(jacobian_copy)) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -192,7 +199,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: list(output_params.keys()), ) return x_copy, jacobian - + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: y_copy = y.copy() transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) @@ -201,8 +208,14 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: ) output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian_copy = { + key1: {key2: jacobian[key1][key2] for key2 in self.name_mapping[1]} + for key1 in self.name_mapping[0] + } + jacobian = jnp.array(jax.tree.leaves(jacobian_copy)) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], From d79af97d25e29b41a3f6690bd9df4db6ef46d54f Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 02:56:38 -0700 Subject: [PATCH 159/172] Both arrival phase and time transform are fully vectorized --- src/jimgw/single_event/transforms.py | 43 ++++++++++++---------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/src/jimgw/single_event/transforms.py 
b/src/jimgw/single_event/transforms.py index dfcd01b8..047c0607 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -214,17 +214,10 @@ def __init__( assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] assert "ra" in conditional_names and "dec" in conditional_names - def _calc_delay(x): - ra, dec = x["ra"], x["dec"] - if hasattr(ra, "shape") and len(ra.shape) > 0: - delay = self.ifo.delay_from_geocenter(ra[0], dec[0], self.gmst) - else: - delay = self.ifo.delay_from_geocenter(ra, dec, self.gmst) - return delay - def named_transform(x): - delay = _calc_delay(x) - t_det = x["t_c"] + delay + t_det = x["t_c"] + jnp.vectorize(self.ifo.delay_from_geocenter)( + x["ra"], x["dec"], self.gmst + ) return { "t_det": t_det, } @@ -232,8 +225,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - delay = _calc_delay(x) - t_c = x["t_det"] - delay + t_c = x["t_det"] - jnp.vectorize(self.ifo.delay_from_geocenter)( + x["ra"], x["dec"], self.gmst + ) return { "t_c": t_c, } @@ -286,25 +280,22 @@ def __init__( and "iota" in conditional_names ) - def _calc_R_det(x): - ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] + @jnp.vectorize + def _calc_R_det_arg(ra, dec, psi, iota, gmst): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - if hasattr(ra, "shape") and len(ra.shape) > 0: - antenna_pattern = self.ifo.antenna_pattern( - ra[0], dec[0], psi[0], self.gmst - ) - else: - antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] - return p_mode_term - 1j * c_mode_term + return jnp.angle(p_mode_term - 1j * c_mode_term) def named_transform(x): - R_det = _calc_R_det(x) - phase_det = jnp.angle(R_det) + x["phase_c"] / 2.0 + R_det_arg = _calc_R_det_arg( + x["ra"], x["dec"], x["psi"], 
x["iota"], self.gmst + ) + phase_det = R_det_arg + x["phase_c"] / 2.0 return { "phase_det": phase_det % (2.0 * jnp.pi), } @@ -312,8 +303,10 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - R_det = _calc_R_det(x) - phase_c = (-jnp.angle(R_det) + x["phase_det"]) * 2.0 + R_det_arg = _calc_R_det_arg( + x["ra"], x["dec"], x["psi"], x["iota"], self.gmst + ) + phase_c = (-R_det_arg + x["phase_det"]) * 2.0 return { "phase_c": phase_c % (2.0 * jnp.pi), } From bcbcbe2e8d6e1d565544cadf9b9ea8e333602cd0 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 03:56:26 -0700 Subject: [PATCH 160/172] Shifting distance transform to conditional --- src/jimgw/single_event/transforms.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 047c0607..4b6ada1f 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -315,7 +315,7 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): +class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): """ Transform the luminosity distance to network SNR weighted distance @@ -332,10 +332,11 @@ class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], ): - super().__init__(name_mapping) + super().__init__(name_mapping, conditional_names) self.gmst = ( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad @@ -343,14 +344,16 @@ def __init__( self.ifos = ifos assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + and "psi" in conditional_names + and "iota" in conditional_names + and "M_c" in 
conditional_names + ) - def _calc_R_dets(x): - ra, dec, psi, iota = ( - x["ra"], - x["dec"], - x["psi"], - x["iota"], - ) + @jnp.vectorize + def _calc_R_dets(ra, dec, psi, iota): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) R_dets2 = 0.0 @@ -367,7 +370,7 @@ def named_transform(x): x["d_L"], x["M_c"], ) - R_dets = _calc_R_dets(x) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_dets return { "d_hat": d_hat, @@ -380,7 +383,7 @@ def named_inverse_transform(x): x["d_hat"], x["M_c"], ) - R_dets = _calc_R_dets(x) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_dets return { "d_L": d_L, From 8dab27b6e259063f03db8f381e8613ca7e34a9d4 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 03:56:48 -0700 Subject: [PATCH 161/172] update example --- test/integration/test_extrinsic.py | 62 ++++++++++--- .../integration/test_extrinsic_no_distance.py | 90 ------------------- 2 files changed, 48 insertions(+), 104 deletions(-) delete mode 100644 test/integration/test_extrinsic_no_distance.py diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index c4719dd8..988f6ea7 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -1,3 +1,12 @@ +import psutil +p = psutil.Process() +p.cpu_affinity([0]) + +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +from astropy.time import Time + import jax import jax.numpy as jnp @@ -21,7 +30,6 @@ ifos = [H1, L1, V1] M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) @@ -33,7 +41,6 @@ prior = CombinePrior( [ M_c_prior, - q_prior, dL_prior, t_c_prior, 
phase_c_prior, @@ -44,31 +51,56 @@ ] ) +# calculate the d_hat range +@jnp.vectorize +def calc_R_dets(ra, dec, psi, iota): + gmst = ( + Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad + ) + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_dets2 = 0.0 + for ifo in ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_dets2 += p_mode_term**2 + c_mode_term**2 + + return jnp.sqrt(R_dets2) + +key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(1234), 4) +# generate 10000 samples for each +ra_samples = ra_prior.sample(key1, 10000)["ra"] +dec_samples = dec_prior.sample(key2, 10000)["dec"] +psi_samples = psi_prior.sample(key3, 10000)["psi"] +iota_samples = iota_prior.sample(key4, 10000)["iota"] +R_dets_samples = calc_R_dets(ra_samples, dec_samples, psi_samples, iota_samples) + +d_hat_min = dL_prior.xmin / jnp.power(M_c_prior.xmax, 5. / 6.) +d_hat_max = dL_prior.xmax / jnp.power(M_c_prior.xmin, 5. / 6.) 
/ jnp.amin(R_dets_samples) + sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], gps_time=gps, ifos=ifos), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det"]], conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=2.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], 
original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), + BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=d_hat_min, original_upper_bound=d_hat_max), ] -likelihood_transforms = [ - MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), -] +likelihood_transforms = [] likelihood = ZeroLikelihood() -mass_matrix = jnp.eye(9) +mass_matrix = jnp.eye(len(prior.base_prior)) #mass_matrix = mass_matrix.at[1, 1].set(1e-3) #mass_matrix = mass_matrix.at[5, 5].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 3e-3} @@ -100,9 +132,11 @@ train_thinning=1, output_thinning=1, local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer, "default"], + strategies=["default"], ) -jim.sample(jax.random.PRNGKey(42)) -jim.get_samples() +print("Start sampling") +key = jax.random.PRNGKey(42) +jim.sample(key) jim.print_summary() +samples = jim.get_samples() diff --git a/test/integration/test_extrinsic_no_distance.py b/test/integration/test_extrinsic_no_distance.py deleted file mode 100644 index d1b1e559..00000000 --- a/test/integration/test_extrinsic_no_distance.py +++ /dev/null @@ -1,90 +0,0 @@ -import jax -import jax.numpy as jnp - -from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior -from jimgw.single_event.detector import H1, L1, V1 -from jimgw.single_event.likelihood import ZeroLikelihood -from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform -from flowMC.strategy.optimization import optimization_Adam - -jax.config.update("jax_enable_x64", True) - 
-########################################### -########## First we grab data ############# -########################################### - -# first, fetch a 4s segment centered on GW150914 -gps = 1126259462.4 - -ifos = [H1, L1, V1] - -t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) -phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) -dec_prior = CosinePrior(parameter_names=["dec"]) - -prior = CombinePrior( - [ - t_c_prior, - phase_c_prior, - ra_prior, - dec_prior, - ] -) - -sample_transforms = [ - # all the user reparametrization transform - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), - # all the bound to unbound transform - BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), -] - -likelihood_transforms = [] - -likelihood = ZeroLikelihood() - -mass_matrix = jnp.eye(9) -#mass_matrix = mass_matrix.at[1, 1].set(1e-3) -#mass_matrix = mass_matrix.at[5, 5].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} - -Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) - -n_epochs = 2 -n_loop_training = 1 -learning_rate = 1e-4 - - -jim = Jim( - likelihood, - prior, - 
sample_transforms=sample_transforms, - likelihood_transforms=likelihood_transforms, - n_loop_training=n_loop_training, - n_loop_production=1, - n_local_steps=5, - n_global_steps=5, - n_chains=4, - n_epochs=n_epochs, - learning_rate=learning_rate, - n_max_examples=30, - n_flow_samples=100, - momentum=0.9, - batch_size=100, - use_global=True, - train_thinning=1, - output_thinning=1, - local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer, "default"], -) - -jim.sample(jax.random.PRNGKey(42)) -jim.get_samples() -jim.print_summary() From fd338825bd10557c67da5b55bac9dc59a75db379 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 14:17:17 -0700 Subject: [PATCH 162/172] Fixing the single sided unbound transform --- src/jimgw/transforms.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index a26ad9af..8b4e2d75 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -418,17 +418,26 @@ class SingleSidedUnboundTransform(BijectiveTransform): """ + original_lower_bound: Float + def __init__( self, name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, ): super().__init__(name_mapping) + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.transform_func = lambda x: { - name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]]) + name_mapping[1][i]: jnp.log( + x[name_mapping[0][i]] - self.original_lower_bound[i] + ) for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.log(x[name_mapping[1][i]]) + name_mapping[0][i]: jnp.exp( + x[name_mapping[1][i]] + self.original_lower_bound[i] + ) for i in range(len(name_mapping[1])) } From 03e76dc1731c793039504b6bf709da79953e17dc Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Sat, 17 Aug 2024 01:52:18 -0700 Subject: [PATCH 163/172] Update extrinsic test --- test/integration/test_extrinsic.py | 19 ++++++------------- 1 file changed, 6 
insertions(+), 13 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index 988f6ea7..7cb5bf32 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -1,10 +1,3 @@ -import psutil -p = psutil.Process() -p.cpu_affinity([0]) - -import os -os.environ['CUDA_VISIBLE_DEVICES'] = '0' - from astropy.time import Time import jax @@ -14,7 +7,7 @@ from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1, V1 from jimgw.single_event.likelihood import ZeroLikelihood -from jimgw.transforms import BoundToUnbound +from jimgw.transforms import BoundToUnbound, SingleSidedUnboundTransform from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform from flowMC.strategy.optimization import optimization_Adam @@ -30,7 +23,7 @@ ifos = [H1, L1, V1] M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +dL_prior = PowerLawPrior(10.0, 200.0, -2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) iota_prior = SinePrior(parameter_names=["iota"]) @@ -93,7 +86,7 @@ def calc_R_dets(ra, dec, psi, iota): BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), - BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], 
original_lower_bound=d_hat_min, original_upper_bound=d_hat_max), + SingleSidedUnboundTransform(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=float(d_hat_min)), ] likelihood_transforms = [] @@ -119,9 +112,9 @@ def calc_R_dets(ra, dec, psi, iota): likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=1, - n_local_steps=5, - n_global_steps=5, - n_chains=4, + n_local_steps=1, + n_global_steps=1, + n_chains=10, n_epochs=n_epochs, learning_rate=learning_rate, n_max_examples=30, From 8fe4b5fdcf43c3a2c6430ba868aeddb3e96bbc72 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 14:52:39 +0200 Subject: [PATCH 164/172] bugfix for single sided transform --- src/jimgw/transforms.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 8b4e2d75..f7d4c702 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -435,9 +435,8 @@ def __init__( for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.exp( - x[name_mapping[1][i]] + self.original_lower_bound[i] - ) + name_mapping[0][i]: jnp.exp(x[name_mapping[1][i]]) + + self.original_lower_bound[i] for i in range(len(name_mapping[1])) } From a19b556a0db3565b141b7508ddf85f3e4250e8ef Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Mon, 19 Aug 2024 05:59:11 -0700 Subject: [PATCH 165/172] Update test --- test/integration/test_extrinsic.py | 28 +--------------------------- 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index 7cb5bf32..55979402 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -23,7 +23,7 @@ ifos = [H1, L1, V1] M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -dL_prior = PowerLawPrior(10.0, 200.0, -2.0, parameter_names=["d_L"]) +dL_prior = PowerLawPrior(10.0, 
2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) iota_prior = SinePrior(parameter_names=["iota"]) @@ -44,33 +44,7 @@ ] ) -# calculate the d_hat range -@jnp.vectorize -def calc_R_dets(ra, dec, psi, iota): - gmst = ( - Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad - ) - p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 - c_iota_term = jnp.cos(iota) - R_dets2 = 0.0 - for ifo in ifos: - antenna_pattern = ifo.antenna_pattern(ra, dec, psi, gmst) - p_mode_term = p_iota_term * antenna_pattern["p"] - c_mode_term = c_iota_term * antenna_pattern["c"] - R_dets2 += p_mode_term**2 + c_mode_term**2 - - return jnp.sqrt(R_dets2) - -key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(1234), 4) -# generate 10000 samples for each -ra_samples = ra_prior.sample(key1, 10000)["ra"] -dec_samples = dec_prior.sample(key2, 10000)["dec"] -psi_samples = psi_prior.sample(key3, 10000)["psi"] -iota_samples = iota_prior.sample(key4, 10000)["iota"] -R_dets_samples = calc_R_dets(ra_samples, dec_samples, psi_samples, iota_samples) - d_hat_min = dL_prior.xmin / jnp.power(M_c_prior.xmax, 5. / 6.) -d_hat_max = dL_prior.xmax / jnp.power(M_c_prior.xmin, 5. / 6.) / jnp.amin(R_dets_samples) sample_transforms = [ # all the user reparametrization transform From 6d2cd9704c10a8345aebfde506ed8bcbeb3e83fe Mon Sep 17 00:00:00 2001 From: "Peter T. H. 
Pang" Date: Mon, 19 Aug 2024 16:25:39 +0200 Subject: [PATCH 166/172] update distance transform --- src/jimgw/single_event/transforms.py | 36 +++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 4b6ada1f..5d8e688b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -328,6 +328,8 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): gmst: Float ifos: list[GroundBased2G] + d_L_min: Float + d_L_max: Float def __init__( self, @@ -335,6 +337,8 @@ def __init__( conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], + d_L_min: Float, + d_L_max: Float, ): super().__init__(name_mapping, conditional_names) @@ -342,8 +346,10 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifos = ifos + self.d_L_min = d_L_min + self.d_L_max = d_L_max - assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] + assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1] assert ( "ra" in conditional_names and "dec" in conditional_names @@ -371,20 +377,38 @@ def named_transform(x): x["M_c"], ) R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) - d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_dets + + scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets + d_hat = scale_factor * d_L + + d_hat_min = scale_factor * self.d_L_min + d_hat_max = scale_factor * self.d_L_max + + y = (d_hat - d_hat_min) / (d_hat_max - d_hat_min) + d_hat_unbounded = jnp.log(y / (1.0 - y)) + return { - "d_hat": d_hat, + "d_hat_unbounded": d_hat_unbounded, } self.transform_func = named_transform def named_inverse_transform(x): - d_hat, M_c = ( - x["d_hat"], + d_hat_unbounded, M_c = ( + x["d_hat_unbounded"], x["M_c"], ) R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) - d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_dets + + scale_factor 
= 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets + + d_hat_min = scale_factor * self.d_L_min + d_hat_max = scale_factor * self.d_L_max + + d_hat = (d_hat_max - d_hat_min) / ( + 1.0 + jnp.exp(-d_hat_unbounded) + ) + d_hat_min + d_L = d_hat / scale_factor return { "d_L": d_L, } From 6993dd9072780203bf99046343f62328c63c7848 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Mon, 19 Aug 2024 08:58:57 -0700 Subject: [PATCH 167/172] Update test --- test/integration/test_extrinsic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index 55979402..a5dc5c7b 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -44,11 +44,10 @@ ] ) -d_hat_min = dL_prior.xmin / jnp.power(M_c_prior.xmax, 5. / 6.) sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, d_L_min=dL_prior.xmin, d_L_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det"]], conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), @@ -60,7 +59,6 @@ BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), 
BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), - SingleSidedUnboundTransform(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=float(d_hat_min)), ] likelihood_transforms = [] From ff65fcf2e75a0655c8e33ddfb32c378c84b3dc8e Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 21:58:57 +0200 Subject: [PATCH 168/172] Update arrival time transform --- src/jimgw/single_event/transforms.py | 58 ++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 5d8e688b..1070adaf 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -196,6 +196,8 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform( gmst: Float ifo: GroundBased2G + tc_min: Float + tc_max: Float def __init__( self, @@ -203,6 +205,8 @@ def __init__( conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, + tc_min: Float, + tc_max: Float, ): super().__init__(name_mapping, conditional_names) @@ -210,24 +214,46 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifo = ifo + self.tc_min = tc_min + self.tc_max = tc_max - assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] + assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1] assert "ra" in conditional_names and "dec" in conditional_names + @jnp.vectorize + def time_delay(ra, dec, gmst): + return self.ifo.delay_from_geocenter(ra, dec, gmst) + def named_transform(x): - t_det = x["t_c"] + jnp.vectorize(self.ifo.delay_from_geocenter)( - x["ra"], x["dec"], self.gmst - ) + + time_shift = time_delay(x["ra"], x["dec"], self.gmst) + + t_det = x["t_c"] + time_shift + t_det_min = self.tc_min + time_shift + t_det_max = self.tc_max + time_shift + + y = (t_det - t_det_min) / (t_det_max - t_det_min) + t_det_unbounded = jnp.log(y / 
(1.0 - y)) return { - "t_det": t_det, + "t_det_unbounded": t_det_unbounded, } self.transform_func = named_transform def named_inverse_transform(x): - t_c = x["t_det"] - jnp.vectorize(self.ifo.delay_from_geocenter)( + + time_shift = jnp.vectorize(self.ifo.delay_from_geocenter)( x["ra"], x["dec"], self.gmst ) + + t_det_min = self.tc_min + time_shift + t_det_max = self.tc_max + time_shift + t_det = (t_det_max - t_det_min) / ( + 1.0 + jnp.exp(-x["t_det_unbounded"]) + ) + t_det_min + + t_c = t_det - time_shift + return { "t_c": t_c, } @@ -328,8 +354,8 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): gmst: Float ifos: list[GroundBased2G] - d_L_min: Float - d_L_max: Float + dL_min: Float + dL_max: Float def __init__( self, @@ -337,8 +363,8 @@ def __init__( conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], - d_L_min: Float, - d_L_max: Float, + dL_min: Float, + dL_max: Float, ): super().__init__(name_mapping, conditional_names) @@ -346,8 +372,8 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifos = ifos - self.d_L_min = d_L_min - self.d_L_max = d_L_max + self.dL_min = dL_min + self.dL_max = dL_max assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1] assert ( @@ -381,8 +407,8 @@ def named_transform(x): scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets d_hat = scale_factor * d_L - d_hat_min = scale_factor * self.d_L_min - d_hat_max = scale_factor * self.d_L_max + d_hat_min = scale_factor * self.dL_min + d_hat_max = scale_factor * self.dL_max y = (d_hat - d_hat_min) / (d_hat_max - d_hat_min) d_hat_unbounded = jnp.log(y / (1.0 - y)) @@ -402,8 +428,8 @@ def named_inverse_transform(x): scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets - d_hat_min = scale_factor * self.d_L_min - d_hat_max = scale_factor * self.d_L_max + d_hat_min = scale_factor * self.dL_min + d_hat_max = scale_factor * self.dL_max d_hat = (d_hat_max - d_hat_min) / ( 1.0 + 
jnp.exp(-d_hat_unbounded) From b98d783db7b16a14c91ca83e7e3cd606697682d7 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Mon, 19 Aug 2024 13:04:47 -0700 Subject: [PATCH 169/172] Update test --- test/integration/test_extrinsic.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index a5dc5c7b..ff79723e 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -49,7 +49,7 @@ # all the user reparametrization transform DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, d_L_min=dL_prior.xmin, d_L_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det"]], conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), @@ -58,7 +58,6 @@ BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - 
BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), ] likelihood_transforms = [] @@ -84,8 +83,8 @@ likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=1, - n_local_steps=1, - n_global_steps=1, + n_local_steps=2, + n_global_steps=2, n_chains=10, n_epochs=n_epochs, learning_rate=learning_rate, From e399a5e9c7609fe1f07ba6951182e5899ac4b692 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 22:18:31 +0200 Subject: [PATCH 170/172] Fix typo --- test/integration/test_extrinsic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index ff79723e..f0e089fe 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -47,7 +47,7 @@ sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, d_L_min=dL_prior.xmin, d_L_max=dL_prior.xmax), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), From 583b759de0859d0d2096bc95d7cccfa1d0b6b824 Mon Sep 17 00:00:00 2001 From: "Peter T. H. 
Pang" Date: Tue, 20 Aug 2024 00:43:21 +0200 Subject: [PATCH 171/172] Fix typo --- src/jimgw/single_event/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 1070adaf..084fe368 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -332,7 +332,7 @@ def named_inverse_transform(x): R_det_arg = _calc_R_det_arg( x["ra"], x["dec"], x["psi"], x["iota"], self.gmst ) - phase_c = (-R_det_arg + x["phase_det"]) * 2.0 + phase_c = -R_det_arg + x["phase_det"] * 2.0 return { "phase_c": phase_c % (2.0 * jnp.pi), } From cd559b6ae127b8683933ddcc85776b8b27f078c6 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Thu, 22 Aug 2024 21:53:41 +0200 Subject: [PATCH 172/172] Adding docstring for zerolikelihood --- src/jimgw/single_event/likelihood.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 9e775b33..1508cdfa 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -27,6 +27,21 @@ def __init__(self, detectors: list[Detector], waveform: Waveform) -> None: class ZeroLikelihood(LikelihoodBase): + """ + A likelihood class that always returns a log-likelihood of zero. + + This class is primarily used for testing or debugging purposes. + + Methods + ------- + __init__() -> None + Initializes the ZeroLikelihood object. No parameters are required or set. + + evaluate(params: dict[str, Float], data: dict) -> Float + Evaluates the likelihood for a given set of parameters and data, + always returning 0.0. This method does not perform any computation + based on the input parameters or data, making it useful for debugging. + """ def __init__(self): pass