From 77e0247c512d0c5ed6c9d736e6015b3fb3160490 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Tue, 23 Jul 2024 13:55:23 -0400 Subject: [PATCH 001/104] Update Single_event_runManager.py --- example/Single_event_runManager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index c0878afc..88ffe52b 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -31,7 +31,6 @@ run = SingleEventRun( seed=0, - path="test_data/GW150914/", detectors=["H1", "L1"], priors={ "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0}, From 75c74a36517147f17e785cdb8dd17e2984279fac Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 10:16:11 -0400 Subject: [PATCH 002/104] scaffolding jim to handle naming. Also isort some import in prior.py --- src/jimgw/jim.py | 8 +++++++- src/jimgw/prior.py | 7 ++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 0ab1105a..15acda7c 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -13,9 +13,15 @@ class Jim(object): """ Master class for interfacing with flowMC - """ + likelihood: LikelihoodBase + prior: Prior + + seed: int + + parameter_names: list[str] + def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): self.Likelihood = likelihood self.Prior = prior diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index f1827753..702dcd31 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -3,12 +3,13 @@ import jax import jax.numpy as jnp +from astropy.time import Time +from beartype import beartype as typechecker from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped -from beartype import beartype as typechecker -from jimgw.single_event.utils import zenith_azimuth_to_ra_dec + from jimgw.single_event.detector import GroundBased2G, detector_preset -from astropy.time import Time +from jimgw.single_event.utils import zenith_azimuth_to_ra_dec class Prior(Distribution): From c749dbd06b512a5c9285889f570012a229522a2e Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 10:19:32 -0400 Subject: [PATCH 003/104] Rename variables in jim --- src/jimgw/jim.py | 46 ++++++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 15acda7c..39c348fd 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -18,13 +18,19 @@ class Jim(object): likelihood: LikelihoodBase prior: Prior - seed: int - parameter_names: list[str] + sampler: Sampler - def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): - self.Likelihood = likelihood - self.Prior = prior + def __init__( + self, + likelihood: LikelihoodBase, + prior: Prior, + parameter_names: list[str], + **kwargs, + ): + self.likelihood = likelihood + self.prior = prior + self.parameter_names = parameter_names seed = kwargs.get("seed", 0) @@ -39,11 +45,11 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): rng_key, subkey = jax.random.split(rng_key) model = MaskedCouplingRQSpline( - self.Prior.n_dim, num_layers, hidden_size, num_bins, subkey + self.prior.n_dim, num_layers, hidden_size, num_bins, subkey ) self.Sampler = Sampler( - self.Prior.n_dim, + self.prior.n_dim, rng_key, None, # type: ignore local_sampler, @@ -52,15 +58,15 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): ) def posterior(self, params: Float[Array, " n_dim"], data: dict): - 
prior_params = self.Prior.add_name(params.T) - prior = self.Prior.log_prob(prior_params) + prior_params = self.prior.add_name(params.T) + prior = self.prior.log_prob(prior_params) return ( - self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior + self.likelihood.evaluate(self.prior.transform(prior_params), data) + prior ) def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: - initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains) + initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T self.Sampler.sample(initial_guess, None) # type: ignore @@ -73,7 +79,7 @@ def maximize_likelihood( ): key = jax.random.PRNGKey(seed) set_nwalkers = set_nwalkers - initial_guess = self.Prior.sample(key, set_nwalkers) + initial_guess = self.prior.sample(key, set_nwalkers) def negative_posterior(x: Float[Array, " n_dim"]): return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet @@ -84,7 +90,7 @@ def negative_posterior(x: Float[Array, " n_dim"]): print("Done compiling") print("Starting the optimizer") - optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True) + optimizer = EvolutionaryOptimizer(self.prior.n_dim, verbose=True) _ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return best_fit @@ -98,19 +104,19 @@ def print_summary(self, transform: bool = True): train_summary = self.Sampler.get_sampler_state(training=True) production_summary = self.Sampler.get_sampler_state(training=False) - training_chain = train_summary["chains"].reshape(-1, self.Prior.n_dim).T - training_chain = self.Prior.add_name(training_chain) + training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T + training_chain = self.prior.add_name(training_chain) if transform: - training_chain = self.Prior.transform(training_chain) + training_chain = self.prior.transform(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] training_loss = train_summary["loss_vals"] - production_chain = production_summary["chains"].reshape(-1, self.Prior.n_dim).T - production_chain = self.Prior.add_name(production_chain) + production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T + production_chain = self.prior.add_name(production_chain) if transform: - production_chain = self.Prior.transform(production_chain) + production_chain = self.prior.transform(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -166,7 +172,7 @@ def get_samples(self, training: bool = False) -> dict: else: chains = self.Sampler.get_sampler_state(training=False)["chains"] - chains = self.Prior.transform(self.Prior.add_name(chains.transpose(2, 0, 1))) + chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1))) return chains def plot(self): From f45b41e83822888adba6ea358794e832c7db0952 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 11:15:58 -0400 Subject: [PATCH 004/104] Starting tracking names in jim --- src/jimgw/jim.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 39c348fd..08b82b05 100644 --- 
a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -18,6 +18,7 @@ class Jim(object): likelihood: LikelihoodBase prior: Prior + # Name of parameters to sample from parameter_names: list[str] sampler: Sampler @@ -57,12 +58,22 @@ def __init__( **kwargs, ) + def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: + """ + Turn an array into a dictionary + + Parameters + ---------- + x : Array + An array of parameters. Shape (n_dim,). + """ + + return dict(zip(self.parameter_names, x)) + def posterior(self, params: Float[Array, " n_dim"], data: dict): - prior_params = self.prior.add_name(params.T) - prior = self.prior.log_prob(prior_params) - return ( - self.likelihood.evaluate(self.prior.transform(prior_params), data) + prior - ) + named_params = self.add_name(params) + prior = self.prior.log_prob(named_params) + return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: From dd6e0de37c6d4ca63aa55bd77b7108e8ca51cef0 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 13:53:00 -0400 Subject: [PATCH 005/104] Add Transform class --- src/jimgw/transforms.py | 63 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/jimgw/transforms.py diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py new file mode 100644 index 00000000..0795d81d --- /dev/null +++ b/src/jimgw/transforms.py @@ -0,0 +1,63 @@ +from dataclasses import field +from typing import Callable, Union + +import jax +import jax.numpy as jnp +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped + + +class Transform: + """ + Base class for transform. + + The idea of transform should be used on distribtuion, + """ + + transform_func: Callable[[Float[Array, " N"]], Float[Array, " M"]] + jacobian_func: Callable[[Float[Array, " N"]], Float] + name_mapping: tuple[list[str], list[str]] + + def __init__( + self, + transform_func: Callable[[Float[Array, " N"]], Float[Array, " M"]], + name_mapping: tuple[list[str], list[str]], + ): + self.transform_func = transform_func + self.jacobian_func = jax.jacfwd(transform_func) + self.name_mapping = name_mapping + + def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + """ + Transform the input x to transformed coordinate y and return the log Jacobian determinant. + This only works if the transform is a N -> N transform. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + log_det : Float + The log Jacobian determinant. + """ + return self.transform_func(x), jnp.log(jnp.linalg.det(self.jacobian_func(x))) + + def push_forward(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Push forward the input x to transformed coordinate y. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. 
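# [editor's note] A minimal, self-contained sketch of the idea the new
# Transform class encodes: pair a map with its log|det Jacobian| obtained
# from jax.jacfwd. The function `affine` below is a toy stand-in, not part
# of the repo.
import jax
import jax.numpy as jnp

def affine(x):
    # toy N -> N map: y = 2x + 1
    return 2.0 * x + 1.0

x = jnp.array([0.5])
jacobian = jax.jacfwd(affine)(x)            # shape (1, 1) for this 1-D map
log_det = jnp.log(jnp.abs(jnp.linalg.det(jacobian)))
y = affine(x)                               # y == [2.0], log_det == log(2)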
+ """ + return self.transform_func(x) \ No newline at end of file From f11f8c889d69caa2dc0b5a1dd2bcc58857bb32e0 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 14:42:34 -0400 Subject: [PATCH 006/104] Add LogitToUniform Transform --- src/jimgw/transforms.py | 64 ++++++++++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 0795d81d..a261a7fe 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,33 +1,35 @@ +from abc import ABC, abstractmethod from dataclasses import field from typing import Callable, Union +from chex import assert_rank import jax import jax.numpy as jnp from beartype import beartype as typechecker -from jaxtyping import Float, Array, jaxtyped +from jaxtyping import Array, Float, jaxtyped -class Transform: +class Transform(ABC): """ Base class for transform. The idea of transform should be used on distribtuion, """ - transform_func: Callable[[Float[Array, " N"]], Float[Array, " M"]] - jacobian_func: Callable[[Float[Array, " N"]], Float] name_mapping: tuple[list[str], list[str]] + transform_func: Callable[[dict[str, Float]], dict[str, Float]] def __init__( self, - transform_func: Callable[[Float[Array, " N"]], Float[Array, " M"]], name_mapping: tuple[list[str], list[str]], ): - self.transform_func = transform_func - self.jacobian_func = jax.jacfwd(transform_func) self.name_mapping = name_mapping def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + return self.transform(x) + + @abstractmethod + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Transform the input x to transformed coordinate y and return the log Jacobian determinant. This only works if the transform is a N -> N transform. @@ -44,9 +46,10 @@ def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. """ - return self.transform_func(x), jnp.log(jnp.linalg.det(self.jacobian_func(x))) - - def push_forward(self, x: dict[str, Float]) -> dict[str, Float]: + raise NotImplementedError + + @abstractmethod + def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ Push forward the input x to transformed coordinate y. @@ -60,4 +63,43 @@ def push_forward(self, x: dict[str, Float]) -> dict[str, Float]: y : dict[str, Float] The transformed dictionary. """ - return self.transform_func(x) \ No newline at end of file + raise NotImplementedError + +class LogitToUniform(Transform): + """ + Transform from unconstrained space to uniform space. + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + bounds : tuple[Float, Float] + The lower and upper bounds of the uniform distribution. 
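# [editor's note] Hedged sketch of the map LogitToUniform builds above,
# assuming bounds = (a, b): an unconstrained x is squashed into (a, b) by a
# scaled sigmoid, and the scalar Jacobian enters the density as log dy/dx.
# Standalone illustration, not the class itself.
import jax
import jax.numpy as jnp

a, b = 10.0, 80.0
to_bounds = lambda x: (b - a) / (1.0 + jnp.exp(-x)) + a
x = 0.3
y = to_bounds(x)                      # lies strictly inside (10, 80)
log_jac = jnp.log(jax.grad(to_bounds)(x))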
+ + """ + + bounds: tuple[Float, Float] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + bounds: tuple[Float, Float], + ): + super().__init__(name_mapping) + self.bounds = bounds + self.transform_func = lambda x: (self.bounds[1] - self.bounds[0]) / (1 + jnp.exp(-x)) + self.bounds[0] + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + input_params = x.pop(self.name_mapping[0][0]) + assert_rank(input_params, 0) + output_params = self.transform_func(input_params) + jacobian = jax.jacfwd(self.transform_func)(input_params) + x[self.name_mapping[1][0]] = output_params + return x, jnp.log(jacobian) + + def forward(self, x: dict[str, Float]) -> dict[str, Float]: + input_params = x.pop(self.name_mapping[0][0]) + assert_rank(input_params, 0) + output_params = self.transform_func(input_params) + x[self.name_mapping[1][0]] = output_params + return x \ No newline at end of file From c1192f2c3a1031ea0b6a68533c27394b03a96732 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:24:11 -0400 Subject: [PATCH 007/104] Add logit distribution --- src/jimgw/prior.py | 79 +++++++++++++++++++++-------------------- src/jimgw/transforms.py | 2 +- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 702dcd31..e7b1acda 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -23,57 +23,19 @@ class Prior(Distribution): """ naming: list[str] - transforms: dict[str, tuple[str, Callable]] = field(default_factory=dict) @property def n_dim(self): return len(self.naming) - def __init__( - self, naming: list[str], transforms: dict[str, tuple[str, Callable]] = {} - ): + def __init__(self, naming: list[str]): """ Parameters ---------- naming : list[str] A list of names for the parameters of the prior. - transforms : dict[tuple[str,Callable]] - A dictionary of transforms to apply to the parameters. The keys are - the names of the parameters and the values are a tuple of the name - of the transform and the transform itself. """ self.naming = naming - self.transforms = {} - - def make_lambda(name): - return lambda x: x[name] - - for name in naming: - if name in transforms: - self.transforms[name] = transforms[name] - else: - # Without the function, the lambda will refer to the variable name instead of its value, - # which will make lambda reference the last value of the variable name - self.transforms[name] = (name, make_lambda(name)) - - def transform(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Apply the transforms to the parameters. - - Parameters - ---------- - x : dict - A dictionary of parameters. Names should match the ones in the prior. - - Returns - ------- - x : dict - A dictionary of parameters with the transforms applied. - """ - output = {} - for value in self.transforms.values(): - output[value[0]] = value[1](x) - return output def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ @@ -96,6 +58,45 @@ def log_prob(self, x: dict[str, Array]) -> Float: raise NotImplementedError +@jaxtyped(typechecker=typechecker) +class Logit(Prior): + + def __repr__(self): + return f"Logit(naming={self.naming})" + + def __init__(self, naming: list[str], **kwargs): + super().__init__(naming) + assert self.n_dim == 1, "Logit needs to be 1D distributions" + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a logit distribution. 
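# [editor's note] Sketch of the logistic ("Logit") distribution added in
# this patch: sampling u ~ U(0, 1) and setting x = log(u / (1 - u)) gives
# the density p(x) = e^{-x} / (1 + e^{-x})^2, i.e.
# log p(x) = -x - 2 log(1 + e^{-x}), exactly the log_prob implemented here.
import jax
import jax.numpy as jnp

u = jax.random.uniform(jax.random.PRNGKey(0), (5,), minval=0.0, maxval=1.0)
x = jnp.log(u / (1.0 - u))                      # inverse-CDF sampling
log_p = -x - 2.0 * jnp.log1p(jnp.exp(-x))       # logistic log-density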
+ + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) + samples = jnp.log(samples / (1 - samples)) + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Float]) -> Float: + variable = x[self.naming[0]] + return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) + + +# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) class Uniform(Prior): xmin: float = 0.0 diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index a261a7fe..d8402cfe 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -2,10 +2,10 @@ from dataclasses import field from typing import Callable, Union -from chex import assert_rank import jax import jax.numpy as jnp from beartype import beartype as typechecker +from chex import assert_rank from jaxtyping import Array, Float, jaxtyped From b3ebc5efd21f958aa4d931a4954f48d2828adcab Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:42:35 -0400 Subject: [PATCH 008/104] Add propagate)name method in transform --- src/jimgw/transforms.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8402cfe..0ee32927 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -27,13 +27,13 @@ def __init__( def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: return self.transform(x) - + @abstractmethod def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Transform the input x to transformed coordinate y and return the log Jacobian determinant. This only works if the transform is a N -> N transform. - + Parameters ---------- x : dict[str, Float] @@ -63,7 +63,14 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: y : dict[str, Float] The transformed dictionary. 
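# [editor's note] Sketch of the name bookkeeping that the upcoming
# propagate_name method implements with plain set algebra (the names here
# are hypothetical): names the transform consumes are dropped, names it
# produces are added, everything else passes through untouched.
input_names = {"x_unbounded", "M_c"}
from_names, to_names = {"x_unbounded"}, {"x"}
output_names = (input_names - from_names) | to_names
# output_names == {"x", "M_c"}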
""" - raise NotImplementedError + raise NotImplementedError + + def propagate_name(self, x: list[str]) -> list[str]: + input_set = set(x) + from_set = set(self.name_mapping[0]) + to_set = set(self.name_mapping[1]) + return list(input_set - from_set | to_set) + class LogitToUniform(Transform): """ @@ -87,7 +94,10 @@ def __init__( ): super().__init__(name_mapping) self.bounds = bounds - self.transform_func = lambda x: (self.bounds[1] - self.bounds[0]) / (1 + jnp.exp(-x)) + self.bounds[0] + self.transform_func = ( + lambda x: (self.bounds[1] - self.bounds[0]) / (1 + jnp.exp(-x)) + + self.bounds[0] + ) def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: input_params = x.pop(self.name_mapping[0][0]) @@ -102,4 +112,4 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: assert_rank(input_params, 0) output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params - return x \ No newline at end of file + return x From 95ef89b0fc466e1efdf15f9a4f0a8284a434343d Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:43:00 -0400 Subject: [PATCH 009/104] Rename composite to Combine for combining priors --- src/jimgw/prior.py | 118 +++++++++++++++++++++++++++------------------ 1 file changed, 70 insertions(+), 48 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e7b1acda..b257ef38 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,6 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec +from jimgw.transforms import Transform class Prior(Distribution): @@ -22,20 +23,20 @@ class Prior(Distribution): the names of the parameters and the transforms that are applied to them. """ - naming: list[str] + parameter_names: list[str] @property def n_dim(self): - return len(self.naming) + return len(self.parameter_names) - def __init__(self, naming: list[str]): + def __init__(self, parameter_names: list[str]): """ Parameters ---------- - naming : list[str] + parameter_names : list[str] A list of names for the parameters of the prior. """ - self.naming = naming + self.parameter_names = parameter_names def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ @@ -47,8 +48,8 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: An array of parameters. Shape (n_dim,). 
""" - return dict(zip(self.naming, x)) - + return dict(zip(self.parameter_names, x)) + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -62,10 +63,10 @@ def log_prob(self, x: dict[str, Array]) -> Float: class Logit(Prior): def __repr__(self): - return f"Logit(naming={self.naming})" + return f"Logit(parameter_names={self.parameter_names})" - def __init__(self, naming: list[str], **kwargs): - super().__init__(naming) + def __init__(self, parameter_names: list[str], **kwargs): + super().__init__(parameter_names) assert self.n_dim == 1, "Logit needs to be 1D distributions" def sample( @@ -92,10 +93,68 @@ def sample( return self.add_name(samples[None]) def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] + variable = x[self.parameter_names[0]] return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) +class Sequential(Prior): + """ + A prior class constructed + """ + + members: list[Prior | Transform] = field(default_factory=list) + + def __repr__(self): + return ( + f"Sequential(priors={self.members}, parameter_names={self.parameter_names})" + ) + + def __init( + self, + members: list[Prior | Transform], + ): + self.members = members + +class Combine(Prior): + """ + A prior class constructed by joinning multiple priors together to form a multivariate prior. + This assumes the priors composing the Combine class are independent. + """ + + priors: list[Prior] = field(default_factory=list) + + def __repr__(self): + return ( + f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" + ) + + def __init__( + self, + priors: list[Prior], + **kwargs, + ): + parameter_names = [] + for prior in priors: + parameter_names += prior.parameter_names + self.priors = priors + self.parameter_names = parameter_names + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + output = {} + for prior in self.priors: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + return output + + def log_prob(self, x: dict[str, Float]) -> Float: + output = 0.0 + for prior in self.priors: + output += prior.log_prob(x) + return output + + # ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) class Uniform(Prior): @@ -686,40 +745,3 @@ def log_prob(self, x: dict[str, Array]) -> Float: - 0.5 * ((variable - self.mean) / self.std) ** 2 ) return output - - -class Composite(Prior): - priors: list[Prior] = field(default_factory=list) - - def __repr__(self): - return f"Composite(priors={self.priors}, naming={self.naming})" - - def __init__( - self, - priors: list[Prior], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - naming = [] - self.transforms = {} - for prior in priors: - naming += prior.naming - self.transforms.update(prior.transforms) - self.priors = priors - self.naming = naming - self.transforms.update(transforms) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - output = {} - for prior in self.priors: - rng_key, subkey = jax.random.split(rng_key) - output.update(prior.sample(subkey, n_samples)) - return output - - def log_prob(self, x: dict[str, Float]) -> Float: - output = 0.0 - for prior in self.priors: - output += prior.log_prob(x) - return output From e6e45efc00b1790a4e0d8ef087aab51bf627472d Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:46:19 -0400 Subject: [PATCH 010/104] 
scaffold prior test --- test/test_prior.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/test/test_prior.py b/test/test_prior.py index 53b065da..98b8b1c4 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -1 +1,16 @@ -from jimgw.prior import Composite, Unconstrained_Uniform, Uniform +from jimgw.prior import * + + +class TestUnivariatePrior: + def test_logit(self): + p = Logit() + +class TestPriorOperations: + def test_combine(self): + raise NotImplementedError + + def test_sequence(self): + raise NotImplementedError + + def test_factor(self): + raise NotImplementedError \ No newline at end of file From d503a37e6089f395ffdfa934ee6cd9f107e66aa8 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 15:49:02 -0400 Subject: [PATCH 011/104] Add Sequential Transform --- src/jimgw/prior.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index b257ef38..3dbd6491 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -97,12 +97,13 @@ def log_prob(self, x: dict[str, Float]) -> Float: return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) -class Sequential(Prior): +class SequentialTransform(Prior): """ - A prior class constructed + Transform a prior distribution by applying a sequence of transforms. """ - members: list[Prior | Transform] = field(default_factory=list) + base_prior: Prior + transforms: list[Transform] def __repr__(self): return ( @@ -111,9 +112,30 @@ def __repr__(self): def __init( self, - members: list[Prior | Transform], + base_prior: Prior, + transforms: list[Transform], ): - self.members = members + + self.base_prior = base_prior + self.transforms = transforms + self.parameter_names = base_prior.parameter_names + for transform in transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + output = self.base_prior.sample(rng_key, n_samples) + for transform in self.transforms: + output, _ = transform.forward(output) + return output + + def log_prob(self, x: dict[str, Float]) -> Float: + output = self.base_prior.log_prob(x) + for transform in self.transforms: + _, log_jacobian = transform.transform(x) + output += log_jacobian + return output class Combine(Prior): """ From 61b3b9ec885fac94f4545cb748c64f5e07deb311 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 16:03:26 -0400 Subject: [PATCH 012/104] Separate logit and scaling. Also rename prior --- src/jimgw/prior.py | 2 +- src/jimgw/transforms.py | 75 ++++++++++++++++++++++++++++++----------- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3dbd6491..30442662 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -60,7 +60,7 @@ def log_prob(self, x: dict[str, Array]) -> Float: @jaxtyped(typechecker=typechecker) -class Logit(Prior): +class LogisticDistribution(Prior): def __repr__(self): return f"Logit(parameter_names={self.parameter_names})" diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 0ee32927..a8d0a908 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -72,32 +72,13 @@ def propagate_name(self, x: list[str]) -> list[str]: return list(input_set - from_set | to_set) -class LogitToUniform(Transform): - """ - Transform from unconstrained space to uniform space. 
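# [editor's note] Hedged sketch of the decomposition this commit performs:
# the old LogitToUniform map is split into a plain sigmoid ("Logit")
# followed by an affine rescaling ("ScaleToRange"), so composing the two
# reproduces the original R -> (a, b) map. Standalone illustration only.
import jax.numpy as jnp

a, b = 0.0, 10.0
z = 0.7                               # an unconstrained input
u = 1.0 / (1.0 + jnp.exp(-z))         # Logit step: R -> (0, 1)
y = (b - a) * u + a                   # ScaleToRange step: (0, 1) -> (a, b)
# y equals the single-step (b - a) / (1 + exp(-z)) + a of LogitToUniform.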
- - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - bounds : tuple[Float, Float] - The lower and upper bounds of the uniform distribution. - - """ - - bounds: tuple[Float, Float] +class UnivariateTransform(Transform): def __init__( self, name_mapping: tuple[list[str], list[str]], - bounds: tuple[Float, Float], ): super().__init__(name_mapping) - self.bounds = bounds - self.transform_func = ( - lambda x: (self.bounds[1] - self.bounds[0]) / (1 + jnp.exp(-x)) - + self.bounds[0] - ) def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: input_params = x.pop(self.name_mapping[0][0]) @@ -113,3 +94,57 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x + + +class ScaleToRange(UnivariateTransform): + + range: tuple[Float, Float] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + range: tuple[Float, Float], + ): + super().__init__(name_mapping) + self.range = range + self.transform_func = ( + lambda x: (self.range[1] - self.range[0]) * x + self.range[0] + ) + + +class Logit(Transform): + """ + Logit transform following + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + + +class Sine(Transform): + """ + Transform from unconstrained space to uniform space. + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: jnp.sin(x) From b47e6aed93c2abe1777aa5a6c8a150164d14e64d Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 16:10:08 -0400 Subject: [PATCH 013/104] Univeriate Transform seems working --- src/jimgw/prior.py | 8 ++++---- src/jimgw/transforms.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 30442662..a9274259 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -63,7 +63,7 @@ def log_prob(self, x: dict[str, Array]) -> Float: class LogisticDistribution(Prior): def __repr__(self): - return f"Logit(parameter_names={self.parameter_names})" + return f"Logistic(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) @@ -107,10 +107,10 @@ class SequentialTransform(Prior): def __repr__(self): return ( - f"Sequential(priors={self.members}, parameter_names={self.parameter_names})" + f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" ) - def __init( + def __init__( self, base_prior: Prior, transforms: list[Transform], @@ -127,7 +127,7 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) for transform in self.transforms: - output, _ = transform.forward(output) + output = jax.vmap(transform.forward)(output) return output def log_prob(self, x: dict[str, Float]) -> Float: diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index a8d0a908..52ec43be 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -112,7 +112,7 @@ def __init__( ) -class Logit(Transform): +class Logit(UnivariateTransform): """ Logit transform following @@ -131,7 +131,7 @@ def __init__( self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) -class Sine(Transform): +class Sine(UnivariateTransform): """ Transform from unconstrained space to uniform space. From e1cc408c2400807d35492a875e5ba1d10823e382 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 16:37:18 -0400 Subject: [PATCH 014/104] Uniform prior now working. Also add prior test --- src/jimgw/prior.py | 133 ++++++++++++++++++---------------------- src/jimgw/transforms.py | 22 ++++--- test/test_prior.py | 8 ++- 3 files changed, 80 insertions(+), 83 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index a9274259..a7a81acc 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform +from jimgw.transforms import Transform, Logit, Scale, Offset class Prior(Distribution): @@ -134,104 +134,87 @@ def log_prob(self, x: dict[str, Float]) -> Float: output = self.base_prior.log_prob(x) for transform in self.transforms: _, log_jacobian = transform.transform(x) - output += log_jacobian + output -= log_jacobian return output -class Combine(Prior): - """ - A prior class constructed by joinning multiple priors together to form a multivariate prior. - This assumes the priors composing the Combine class are independent. - """ +# class Combine(Prior): +# """ +# A prior class constructed by joinning multiple priors together to form a multivariate prior. +# This assumes the priors composing the Combine class are independent. 
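# [editor's note] Sketch of the density bookkeeping behind the reworked
# Uniform: with a logistic base z and forward transforms sigmoid then
# scale-by-range, log p_y = log p_z - sum(log-Jacobians) collapses to
# -log(xmax - xmin), which is what the added test asserts. Illustrative
# arithmetic, not the class code.
import jax
import jax.numpy as jnp

z = 0.3
log_base = -z - 2.0 * jnp.log1p(jnp.exp(-z))    # logistic log-density
sigmoid = lambda t: 1.0 / (1.0 + jnp.exp(-t))
log_jacobian = jnp.log(jax.grad(sigmoid)(z)) + jnp.log(10.0)  # range 0..10
# log_base - log_jacobian == -log(10.0) for any z, i.e. Uniform(0, 10).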
+# """ + +# priors: list[Prior] = field(default_factory=list) + +# def __repr__(self): +# return ( +# f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" +# ) + +# def __init__( +# self, +# priors: list[Prior], +# **kwargs, +# ): +# parameter_names = [] +# for prior in priors: +# parameter_names += prior.parameter_names +# self.priors = priors +# self.parameter_names = parameter_names + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# output = {} +# for prior in self.priors: +# rng_key, subkey = jax.random.split(rng_key) +# output.update(prior.sample(subkey, n_samples)) +# return output + +# def log_prob(self, x: dict[str, Float]) -> Float: +# output = 0.0 +# for prior in self.priors: +# output -= prior.log_prob(x) +# return output - priors: list[Prior] = field(default_factory=list) - def __repr__(self): - return ( - f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" - ) - def __init__( - self, - priors: list[Prior], - **kwargs, - ): - parameter_names = [] - for prior in priors: - parameter_names += prior.parameter_names - self.priors = priors - self.parameter_names = parameter_names - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - output = {} - for prior in self.priors: - rng_key, subkey = jax.random.split(rng_key) - output.update(prior.sample(subkey, n_samples)) - return output - - def log_prob(self, x: dict[str, Float]) -> Float: - output = 0.0 - for prior in self.priors: - output += prior.log_prob(x) - return output - - -# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) class Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 + _dist: Prior + + xmin: float + xmax: float def __repr__(self): return f"Uniform(xmin={self.xmin}, xmax={self.xmax})" def __init__( self, - xmin: Float, - xmax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, + xmin: float, + xmax: float, + parameter_names: list[str], ): - super().__init__(naming, transforms) + super().__init__(parameter_names) assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin + self._dist = SequentialTransform( + LogisticDistribution(parameter_names), + [ + Logit((parameter_names, parameter_names)), + Scale((parameter_names, parameter_names), xmax - xmin), + Offset((parameter_names, parameter_names), xmin), + ]) def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a uniform distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. 
- - """ - samples = jax.random.uniform( - rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax - ) - return self.add_name(samples[None]) + return self._dist.sample(rng_key, n_samples) def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - return output + jnp.log(1.0 / (self.xmax - self.xmin)) + return self._dist.log_prob(x) +# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) class Unconstrained_Uniform(Prior): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 52ec43be..203f6bf1 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -95,21 +95,29 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: x[self.name_mapping[1][0]] = output_params return x +class Scale(UnivariateTransform): + scale: Float -class ScaleToRange(UnivariateTransform): + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + scale: Float, + ): + super().__init__(name_mapping) + self.scale = scale + self.transform_func = lambda x: x * self.scale - range: tuple[Float, Float] +class Offset(UnivariateTransform): + offset: Float def __init__( self, name_mapping: tuple[list[str], list[str]], - range: tuple[Float, Float], + offset: Float, ): super().__init__(name_mapping) - self.range = range - self.transform_func = ( - lambda x: (self.range[1] - self.range[0]) * x + self.range[0] - ) + self.offset = offset + self.transform_func = lambda x: x + self.offset class Logit(UnivariateTransform): diff --git a/test/test_prior.py b/test/test_prior.py index 98b8b1c4..6890f9b9 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -2,9 +2,15 @@ class TestUnivariatePrior: - def test_logit(self): + def test_logistic(self): p = Logit() + def test_uniform(self): + p = Uniform(0.0, 10.0, ['x']) + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.allclose(log_prob, -jnp.log(10.0)) + class TestPriorOperations: def test_combine(self): raise NotImplementedError From d68cb366e3d8b3ce22aa0dc8f2bfb1f412eb747a Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 17:10:29 -0400 Subject: [PATCH 015/104] Added inverse transform and uniform now perform correct --- src/jimgw/prior.py | 83 +++-------------------------------------- src/jimgw/transforms.py | 56 ++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 80 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index a7a81acc..35722707 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -131,10 +131,11 @@ def sample( return output def log_prob(self, x: dict[str, Float]) -> Float: - output = self.base_prior.log_prob(x) - for transform in self.transforms: - _, log_jacobian = transform.transform(x) - output -= log_jacobian + output = 0.0 + for transform in reversed(self.transforms): + x, log_jacobian = transform.inverse_transform(x) + output += log_jacobian + output += self.base_prior.log_prob(x) return output # class Combine(Prior): @@ -216,80 +217,6 @@ def log_prob(self, x: dict[str, Array]) -> Float: # ====================== Things below may need rework ====================== -@jaxtyped(typechecker=typechecker) -class Unconstrained_Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 - - def __repr__(self): - return f"Unconstrained_Uniform(xmin={self.xmin}, xmax={self.xmax})" - - def __init__( - self, - 
xmin: Float, - xmax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - local_transform = self.transforms - - def new_transform(param): - param[self.naming[0]] = self.to_range(param[self.naming[0]]) - return local_transform[self.naming[0]][1](param) - - self.transforms = { - self.naming[0]: (local_transform[self.naming[0]][0], new_transform) - } - - def to_range(self, x: Float) -> Float: - """ - Transform the parameters to the range of the prior. - - Parameters - ---------- - x : Float - The parameters to transform. - - Returns - ------- - x : dict - A dictionary of parameters with the transforms applied. - """ - return (self.xmax - self.xmin) / (1 + jnp.exp(-x)) + self.xmin - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a uniform distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : - An array of shape (n_samples, n_dim) containing the samples. - - """ - samples = jax.random.uniform(rng_key, (n_samples,), minval=0, maxval=1) - samples = jnp.log(samples / (1 - samples)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) - - class Sphere(Prior): """ A prior on a sphere represented by Cartesian coordinates. diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 203f6bf1..c5e9120a 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -18,6 +18,7 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] + inverse_func: Callable[[dict[str, Float]], dict[str, Float]] def __init__( self, @@ -47,6 +48,23 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. """ raise NotImplementedError + + @abstractmethod + def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Inverse transform the input x to transformed coordinate y. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + """ + raise NotImplementedError @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: @@ -64,6 +82,23 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: The transformed dictionary. """ raise NotImplementedError + + @abstractmethod + def backward(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Pull back the input x to transformed coordinate y. + + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. 
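# [editor's note] Sketch of the contract the new inverse_func/backward pair
# is meant to satisfy (illustrative lambdas, not the repo's transforms):
# backward must undo forward exactly, so their Jacobians are reciprocal.
import jax.numpy as jnp

scale = 3.0
forward_fn = lambda x: x * scale
backward_fn = lambda y: y / scale
assert jnp.allclose(backward_fn(forward_fn(jnp.array(2.0))), 2.0)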
+ """ + raise NotImplementedError def propagate_name(self, x: list[str]) -> list[str]: input_set = set(x) @@ -88,12 +123,28 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: x[self.name_mapping[1][0]] = output_params return x, jnp.log(jacobian) + def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]: + output_params = x.pop(self.name_mapping[1][0]) + assert_rank(output_params, 0) + input_params = self.inverse_func(output_params) + jacobian = jax.jacfwd(self.inverse_func)(output_params) + x[self.name_mapping[0][0]] = input_params + return x, jnp.log(jacobian) + def forward(self, x: dict[str, Float]) -> dict[str, Float]: input_params = x.pop(self.name_mapping[0][0]) assert_rank(input_params, 0) output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x + + def backward(self, x: dict[str, Float]) -> dict[str, Float]: + output_params = x.pop(self.name_mapping[1][0]) + assert_rank(output_params, 0) + input_params = self.inverse_func(output_params) + x[self.name_mapping[0][0]] = input_params + return x + class Scale(UnivariateTransform): scale: Float @@ -106,6 +157,7 @@ def __init__( super().__init__(name_mapping) self.scale = scale self.transform_func = lambda x: x * self.scale + self.inverse_func = lambda x: x / self.scale class Offset(UnivariateTransform): offset: Float @@ -118,7 +170,7 @@ def __init__( super().__init__(name_mapping) self.offset = offset self.transform_func = lambda x: x + self.offset - + self.inverse_func = lambda x: x - self.offset class Logit(UnivariateTransform): """ @@ -137,7 +189,7 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) - + self.inverse_func = lambda x: jnp.log(x / (1 - x)) class Sine(UnivariateTransform): """ From c16eef5bec5b04aa7b2bddc8788dabfd6f801c06 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 24 Jul 2024 17:13:11 -0400 Subject: [PATCH 016/104] Add comments to current sequential transform prior class --- src/jimgw/prior.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 35722707..9fee29d9 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -131,6 +131,10 @@ def sample( return output def log_prob(self, x: dict[str, Float]) -> Float: + """ + Requiring inverse transform in log_prob may not be the best option, + may need alternative + """ output = 0.0 for transform in reversed(self.transforms): x, log_jacobian = transform.inverse_transform(x) From 8ab92a481aac134ecef29883ea04eae725242fb3 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 13:13:34 -0400 Subject: [PATCH 017/104] Removing inverse. Inverse limts the type of transform one can use, And it doesn't seem to have case that will require log_prob on target space --- src/jimgw/prior.py | 14 +++++------ src/jimgw/transforms.py | 56 ----------------------------------------- 2 files changed, 7 insertions(+), 63 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 9fee29d9..61202609 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -132,14 +132,14 @@ def sample( def log_prob(self, x: dict[str, Float]) -> Float: """ - Requiring inverse transform in log_prob may not be the best option, - may need alternative + log_prob has to be evaluated in the space of the base_prior. 
+ + """ - output = 0.0 - for transform in reversed(self.transforms): - x, log_jacobian = transform.inverse_transform(x) - output += log_jacobian - output += self.base_prior.log_prob(x) + output = self.base_prior.log_prob(x) + for transform in self.transforms: + x, log_jacobian = transform.transform(x) + output -= log_jacobian return output # class Combine(Prior): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index c5e9120a..7ef8560a 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -8,7 +8,6 @@ from chex import assert_rank from jaxtyping import Array, Float, jaxtyped - class Transform(ABC): """ Base class for transform. @@ -18,8 +17,6 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] - inverse_func: Callable[[dict[str, Float]], dict[str, Float]] - def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -49,23 +46,6 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ raise NotImplementedError - @abstractmethod - def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Inverse transform the input x to transformed coordinate y. - - Parameters - ---------- - x : dict[str, Float] - The input dictionary. - - Returns - ------- - y : dict[str, Float] - The transformed dictionary. - """ - raise NotImplementedError - @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -82,23 +62,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: The transformed dictionary. """ raise NotImplementedError - - @abstractmethod - def backward(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Pull back the input x to transformed coordinate y. - - Parameters - ---------- - x : dict[str, Float] - The input dictionary. - - Returns - ------- - y : dict[str, Float] - The transformed dictionary. 
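# [editor's note] Sketch of the rationale stated in this commit: if samples
# are kept in the base space, the density only ever needs the forward map's
# Jacobian, log p_target = log p_base - log|dy/dz|, so a transform such as
# exp below needs no inverse at all. Standalone illustration.
import jax
import jax.numpy as jnp

fwd = lambda z: jnp.exp(z)                 # a forward-only transform
z = 0.5
log_p_base = -0.5 * z**2                   # toy (unnormalized) base density
log_p_target = log_p_base - jnp.log(jax.grad(fwd)(z))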
- """ - raise NotImplementedError def propagate_name(self, x: list[str]) -> list[str]: input_set = set(x) @@ -123,14 +86,6 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: x[self.name_mapping[1][0]] = output_params return x, jnp.log(jacobian) - def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]: - output_params = x.pop(self.name_mapping[1][0]) - assert_rank(output_params, 0) - input_params = self.inverse_func(output_params) - jacobian = jax.jacfwd(self.inverse_func)(output_params) - x[self.name_mapping[0][0]] = input_params - return x, jnp.log(jacobian) - def forward(self, x: dict[str, Float]) -> dict[str, Float]: input_params = x.pop(self.name_mapping[0][0]) assert_rank(input_params, 0) @@ -138,14 +93,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: x[self.name_mapping[1][0]] = output_params return x - def backward(self, x: dict[str, Float]) -> dict[str, Float]: - output_params = x.pop(self.name_mapping[1][0]) - assert_rank(output_params, 0) - input_params = self.inverse_func(output_params) - x[self.name_mapping[0][0]] = input_params - return x - - class Scale(UnivariateTransform): scale: Float @@ -157,7 +104,6 @@ def __init__( super().__init__(name_mapping) self.scale = scale self.transform_func = lambda x: x * self.scale - self.inverse_func = lambda x: x / self.scale class Offset(UnivariateTransform): offset: Float @@ -170,7 +116,6 @@ def __init__( super().__init__(name_mapping) self.offset = offset self.transform_func = lambda x: x + self.offset - self.inverse_func = lambda x: x - self.offset class Logit(UnivariateTransform): """ @@ -189,7 +134,6 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) - self.inverse_func = lambda x: jnp.log(x / (1 - x)) class Sine(UnivariateTransform): """ From d8b2d1feb0bdae6b4fd59bef8cc345eca2bb8c33 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 13:48:17 -0400 Subject: [PATCH 018/104] Add transformation function --- src/jimgw/prior.py | 19 +++++++++++++++---- test/test_prior.py | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 61202609..232b4a0f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -126,9 +126,7 @@ def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) - for transform in self.transforms: - output = jax.vmap(transform.forward)(output) - return output + return jax.vmap(self.transform)(output) def log_prob(self, x: dict[str, Float]) -> Float: """ @@ -141,6 +139,11 @@ def log_prob(self, x: dict[str, Float]) -> Float: x, log_jacobian = transform.transform(x) output -= log_jacobian return output + + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + for transform in self.transforms: + x = transform.forward(x) + return x # class Combine(Prior): # """ @@ -185,7 +188,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped(typechecker=typechecker) class Uniform(Prior): - _dist: Prior + _dist: SequentialTransform xmin: float xmax: float @@ -218,6 +221,14 @@ def sample( def log_prob(self, x: dict[str, Array]) -> Float: return self._dist.log_prob(x) + + def sample_base( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self._dist.base_prior.sample(rng_key, n_samples) + + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + return self._dist.transform(x) # 
====================== Things below may need rework ====================== diff --git a/test/test_prior.py b/test/test_prior.py index 6890f9b9..b6eb7c87 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -7,7 +7,7 @@ def test_logistic(self): def test_uniform(self): p = Uniform(0.0, 10.0, ['x']) - samples = p.sample(jax.random.PRNGKey(0), 10000) + samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, -jnp.log(10.0)) From ca1d6b6c1fa2ac7cae0d9a313693ea97af070fe4 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 14:07:16 -0400 Subject: [PATCH 019/104] Combine should be working now --- src/jimgw/prior.py | 75 +++++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 232b4a0f..85414064 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -145,44 +145,43 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]: x = transform.forward(x) return x -# class Combine(Prior): -# """ -# A prior class constructed by joinning multiple priors together to form a multivariate prior. -# This assumes the priors composing the Combine class are independent. -# """ - -# priors: list[Prior] = field(default_factory=list) - -# def __repr__(self): -# return ( -# f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" -# ) - -# def __init__( -# self, -# priors: list[Prior], -# **kwargs, -# ): -# parameter_names = [] -# for prior in priors: -# parameter_names += prior.parameter_names -# self.priors = priors -# self.parameter_names = parameter_names - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# output = {} -# for prior in self.priors: -# rng_key, subkey = jax.random.split(rng_key) -# output.update(prior.sample(subkey, n_samples)) -# return output - -# def log_prob(self, x: dict[str, Float]) -> Float: -# output = 0.0 -# for prior in self.priors: -# output -= prior.log_prob(x) -# return output +class Combine(Prior): + """ + A prior class constructed by joinning multiple priors together to form a multivariate prior. + This assumes the priors composing the Combine class are independent. 
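# [editor's note] Sketch of the independence assumption behind the restored
# Combine prior: the PRNG key is split once per member for sampling, and the
# joint log-probability is the plain sum of the members' log-probabilities.
# Hypothetical two-parameter example below, not the class itself.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key, k1 = jax.random.split(key)
key, k2 = jax.random.split(key)
samples = {
    "a": jax.random.uniform(k1, (100,)),             # member prior 1
    "b": jax.random.uniform(k2, (100,), maxval=2.0)  # member prior 2
}
joint_log_prob = jnp.log(1.0 / 1.0) + jnp.log(1.0 / 2.0)  # sum over members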
+ """ + + priors: list[Prior] = field(default_factory=list) + + def __repr__(self): + return ( + f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" + ) + + def __init__( + self, + priors: list[Prior], + ): + parameter_names = [] + for prior in priors: + parameter_names += prior.parameter_names + self.priors = priors + self.parameter_names = parameter_names + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + output = {} + for prior in self.priors: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + return output + + def log_prob(self, x: dict[str, Float]) -> Float: + output = 0.0 + for prior in self.priors: + output += prior.log_prob(x) + return output From 2f3f412446ec346a8dbfe367dc7acfee7bf4c7d1 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 14:52:02 -0400 Subject: [PATCH 020/104] Sine is an illegal transform since its Jacobian could be negative --- src/jimgw/prior.py | 2 +- src/jimgw/transforms.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 85414064..1e7cb960 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -193,7 +193,7 @@ class Uniform(Prior): xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax})" + return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" def __init__( self, diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 7ef8560a..55a5d0ce 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -135,10 +135,10 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) -class Sine(UnivariateTransform): +class ArcSine(UnivariateTransform): """ - Transform from unconstrained space to uniform space. 
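# [editor's note] Sketch of the reasoning given in this commit's subject:
# d/dx sin(x) = cos(x) changes sign, so jnp.log(jacobian) can be NaN, while
# d/dx arcsin(x) = 1 / sqrt(1 - x^2) is strictly positive on (-1, 1), which
# keeps the log-Jacobian well defined for the replacement ArcSine.
import jax
import jax.numpy as jnp

d_arcsin = jax.grad(jnp.arcsin)(0.5)   # ~1.1547, always > 0 on (-1, 1)
d_sin = jnp.cos(2.0)                   # negative, so its log is NaN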
- + ArcSine transformation + Parameters ---------- name_mapping : tuple[list[str], list[str]] @@ -151,4 +151,4 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: jnp.sin(x) + self.transform_func = lambda x: jnp.arcsin(x) From 4978cb1bd8fbf7e1083d0552ebe198ab44de1040 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 25 Jul 2024 18:24:33 -0400 Subject: [PATCH 021/104] Modify Uniform and add UniformSphere --- src/jimgw/prior.py | 143 +++++++++++++++------------------------- src/jimgw/transforms.py | 37 +++++++++-- 2 files changed, 82 insertions(+), 98 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 1e7cb960..effbf388 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcCosine class Prior(Distribution): @@ -49,7 +49,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ return dict(zip(self.parameter_names, x)) - + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -106,9 +106,7 @@ class SequentialTransform(Prior): transforms: list[Transform] def __repr__(self): - return ( - f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" - ) + return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" def __init__( self, @@ -127,24 +125,28 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - + def log_prob(self, x: dict[str, Float]) -> Float: """ log_prob has to be evaluated in the space of the base_prior. - - """ output = self.base_prior.log_prob(x) for transform in self.transforms: x, log_jacobian = transform.transform(x) output -= log_jacobian return output - + + def sample_base( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self.base_prior.sample(rng_key, n_samples) + def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) return x + class Combine(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. 
@@ -184,16 +186,13 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output - @jaxtyped(typechecker=typechecker) -class Uniform(Prior): - _dist: SequentialTransform - +class Uniform(SequentialTransform): xmin: float xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" def __init__( self, @@ -201,94 +200,56 @@ def __init__( xmax: float, parameter_names: list[str], ): - super().__init__(parameter_names) + self.parameter_names = parameter_names assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - self._dist = SequentialTransform( - LogisticDistribution(parameter_names), + super().__init__( + LogisticDistribution(self.parameter_names), [ - Logit((parameter_names, parameter_names)), - Scale((parameter_names, parameter_names), xmax - xmin), - Offset((parameter_names, parameter_names), xmin), - ]) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self._dist.sample(rng_key, n_samples) - - def log_prob(self, x: dict[str, Array]) -> Float: - return self._dist.log_prob(x) - - def sample_base( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self._dist.base_prior.sample(rng_key, n_samples) - - def transform(self, x: dict[str, Float]) -> dict[str, Float]: - return self._dist.transform(x) - -# ====================== Things below may need rework ====================== + Logit((self.parameter_names, self.parameter_names)), + Scale((self.parameter_names, self.parameter_names), xmax - xmin), + Offset((self.parameter_names, self.parameter_names), xmin), + ], + ) -class Sphere(Prior): - """ - A prior on a sphere represented by Cartesian coordinates. - Magnitude is sampled from a uniform distribution. 
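# [Illustrative sketch of the composition above; assumes only jax, not jimgw.]
# A logistic base draw pushed through sigmoid, then scaled by (xmax - xmin)
# and offset by xmin, reproduces Uniform(xmin, xmax) - which is what the
# rewritten Uniform class encodes as LogisticDistribution + Logit/Scale/Offset.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.logistic(key, (100_000,))   # base distribution
u = 1.0 / (1.0 + jnp.exp(-x))              # Logit step: now Uniform(0, 1)
y = 10.0 + (80.0 - 10.0) * u               # Scale then Offset: Uniform(10, 80)
print(y.min(), y.max(), y.mean())          # ~10, ~80, mean near 45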
- """ +@jaxtyped(typechecker=typechecker) +class UniformSphere(Combine): def __repr__(self): - return f"Sphere(naming={self.naming})" + return f"UniformSphere(parameter_names={self.parameter_names})" - def __init__(self, naming: list[str], **kwargs): - name = naming[0] - self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] - self.transforms = { - self.naming[0]: ( - f"{naming}_x", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.cos(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[1]: ( - f"{naming}_y", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.sin(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[2]: ( - f"{naming}_z", - lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 3) - theta = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) + def __init__(self, parameter_names: list[str], **kwargs): + assert ( + len(parameter_names) == 1 + ), "UniformSphere only takes the name of the vector" + parameter_names = parameter_names[0] + self.parameter_names = [ + f"{parameter_names}_mag", + f"{parameter_names}_theta", + f"{parameter_names}_phi", + ] + super().__init__( + [ + Uniform(0.0, 1.0, [self.parameter_names[0]]), + SequentialTransform( + Uniform(-1.0, 1.0, [f"cos_{self.parameter_names[1]}"]), + [ + ArcCosine( + ( + [f"cos_{self.parameter_names[1]}"], + [self.parameter_names[1]], + ) + ) + ], + ), + Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + ] ) - phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) - mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) - return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - def log_prob(self, x: dict[str, Float]) -> Float: - theta = x[self.naming[0]] - phi = x[self.naming[1]] - mag = x[self.naming[2]] - output = jnp.where( - (mag > 1) - | (mag < 0) - | (phi > 2 * jnp.pi) - | (phi < 0) - | (theta > jnp.pi) - | (theta < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), - ) - return output + +# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 55a5d0ce..d8ee9536 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -from dataclasses import field -from typing import Callable, Union +from typing import Callable import jax import jax.numpy as jnp -from beartype import beartype as typechecker from chex import assert_rank -from jaxtyping import Array, Float, jaxtyped +from jaxtyping import Float + class Transform(ABC): """ @@ -17,6 +16,7 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] + def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -45,7 +45,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. 
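# [Illustrative check of the log-Jacobian bookkeeping above; assumes only jax.]
# The log Jacobian determinant returned by a transform feeds the
# change-of-variables identity
#     log p_Y(T(x)) = log p_X(x) - log |T'(x)|.
# For a logistic base density pushed through the sigmoid, the result is the
# Uniform(0, 1) log-density, i.e. exactly 0.
import jax
import jax.numpy as jnp

log_logistic = lambda x: -x - 2.0 * jnp.log1p(jnp.exp(-x))
T = lambda x: 1.0 / (1.0 + jnp.exp(-x))

x = 0.3
log_jac = jnp.log(jnp.abs(jax.grad(T)(x)))
print(log_logistic(x) - log_jac)  # ~0.0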
""" raise NotImplementedError - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -92,7 +92,8 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x - + + class Scale(UnivariateTransform): scale: Float @@ -105,6 +106,7 @@ def __init__( self.scale = scale self.transform_func = lambda x: x * self.scale + class Offset(UnivariateTransform): offset: Float @@ -117,6 +119,7 @@ def __init__( self.offset = offset self.transform_func = lambda x: x + self.offset + class Logit(UnivariateTransform): """ Logit transform following @@ -135,10 +138,11 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + class ArcSine(UnivariateTransform): """ ArcSine transformation - + Parameters ---------- name_mapping : tuple[list[str], list[str]] @@ -152,3 +156,22 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) + + +class ArcCosine(UnivariateTransform): + """ + ArcCosine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: jnp.arccos(x) From c1115bd5b669bc8a3bcc7e30fe5e7405beb4ba97 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 06:27:36 -0400 Subject: [PATCH 022/104] Add Sine and Cosine Prior --- src/jimgw/prior.py | 50 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index effbf388..894ea9f7 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcCosine +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine class Prior(Distribution): @@ -214,6 +214,42 @@ def __init__( ) +@jaxtyped(typechecker=typechecker) +class Sine(SequentialTransform): + """ + A prior distribution where the pdf is proportional to sin(x) in the range [0, pi]. + """ + + def __repr__(self): + return f"Sine(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 1, "Sine needs to be 1D distributions" + super().__init__( + Uniform(-1.0, 1.0, f"cos_{self.parameter_names}"), + [ArcCosine(([f"cos_{self.parameter_names}"], [self.parameter_names]))], + ) + + +@jaxtyped(typechecker=typechecker) +class Cosine(SequentialTransform): + """ + A prior distribution where the pdf is proportional to cos(x) in the range [-pi/2, pi/2]. 
+ """ + + def __repr__(self): + return f"Cosine(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 1, "Cosine needs to be 1D distributions" + super().__init__( + Uniform(-1.0, 1.0, f"sin_{self.parameter_names}"), + [ArcSine(([f"sin_{self.parameter_names}"], [self.parameter_names]))], + ) + + @jaxtyped(typechecker=typechecker) class UniformSphere(Combine): @@ -233,17 +269,7 @@ def __init__(self, parameter_names: list[str], **kwargs): super().__init__( [ Uniform(0.0, 1.0, [self.parameter_names[0]]), - SequentialTransform( - Uniform(-1.0, 1.0, [f"cos_{self.parameter_names[1]}"]), - [ - ArcCosine( - ( - [f"cos_{self.parameter_names[1]}"], - [self.parameter_names[1]], - ) - ) - ], - ), + Sine([self.parameter_names[1]]), Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), ] ) From 01a6c1e65ebc4bdc53df48bad87afe1f08d47c58 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 06:57:11 -0400 Subject: [PATCH 023/104] Revert "Modify Uniform and add UniformSphere" This reverts commit 4978cb1bd8fbf7e1083d0552ebe198ab44de1040. --- src/jimgw/prior.py | 143 +++++++++++++++++++++++++--------------- src/jimgw/transforms.py | 37 ++--------- 2 files changed, 98 insertions(+), 82 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index effbf388..1e7cb960 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcCosine +from jimgw.transforms import Transform, Logit, Scale, Offset class Prior(Distribution): @@ -49,7 +49,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ return dict(zip(self.parameter_names, x)) - + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -106,7 +106,9 @@ class SequentialTransform(Prior): transforms: list[Transform] def __repr__(self): - return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" + return ( + f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" + ) def __init__( self, @@ -125,28 +127,24 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - + def log_prob(self, x: dict[str, Float]) -> Float: """ log_prob has to be evaluated in the space of the base_prior. + + """ output = self.base_prior.log_prob(x) for transform in self.transforms: x, log_jacobian = transform.transform(x) output -= log_jacobian return output - - def sample_base( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self.base_prior.sample(rng_key, n_samples) - + def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) return x - class Combine(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. 
@@ -186,13 +184,16 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output + @jaxtyped(typechecker=typechecker) -class Uniform(SequentialTransform): +class Uniform(Prior): + _dist: SequentialTransform + xmin: float xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" + return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" def __init__( self, @@ -200,56 +201,94 @@ def __init__( xmax: float, parameter_names: list[str], ): - self.parameter_names = parameter_names + super().__init__(parameter_names) assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - super().__init__( - LogisticDistribution(self.parameter_names), + self._dist = SequentialTransform( + LogisticDistribution(parameter_names), [ - Logit((self.parameter_names, self.parameter_names)), - Scale((self.parameter_names, self.parameter_names), xmax - xmin), - Offset((self.parameter_names, self.parameter_names), xmin), - ], - ) + Logit((parameter_names, parameter_names)), + Scale((parameter_names, parameter_names), xmax - xmin), + Offset((parameter_names, parameter_names), xmin), + ]) + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self._dist.sample(rng_key, n_samples) -@jaxtyped(typechecker=typechecker) -class UniformSphere(Combine): + def log_prob(self, x: dict[str, Array]) -> Float: + return self._dist.log_prob(x) + + def sample_base( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self._dist.base_prior.sample(rng_key, n_samples) + + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + return self._dist.transform(x) + +# ====================== Things below may need rework ====================== + +class Sphere(Prior): + """ + A prior on a sphere represented by Cartesian coordinates. + + Magnitude is sampled from a uniform distribution. 
+ """ def __repr__(self): - return f"UniformSphere(parameter_names={self.parameter_names})" + return f"Sphere(naming={self.naming})" - def __init__(self, parameter_names: list[str], **kwargs): - assert ( - len(parameter_names) == 1 - ), "UniformSphere only takes the name of the vector" - parameter_names = parameter_names[0] - self.parameter_names = [ - f"{parameter_names}_mag", - f"{parameter_names}_theta", - f"{parameter_names}_phi", - ] - super().__init__( - [ - Uniform(0.0, 1.0, [self.parameter_names[0]]), - SequentialTransform( - Uniform(-1.0, 1.0, [f"cos_{self.parameter_names[1]}"]), - [ - ArcCosine( - ( - [f"cos_{self.parameter_names[1]}"], - [self.parameter_names[1]], - ) - ) - ], - ), - Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), - ] - ) + def __init__(self, naming: list[str], **kwargs): + name = naming[0] + self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] + self.transforms = { + self.naming[0]: ( + f"{naming}_x", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.cos(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[1]: ( + f"{naming}_y", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.sin(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[2]: ( + f"{naming}_z", + lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], + ), + } + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + rng_keys = jax.random.split(rng_key, 3) + theta = jnp.arccos( + jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) + ) + phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) + mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) + return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) -# ====================== Things below may need rework ====================== + def log_prob(self, x: dict[str, Float]) -> Float: + theta = x[self.naming[0]] + phi = x[self.naming[1]] + mag = x[self.naming[2]] + output = jnp.where( + (mag > 1) + | (mag < 0) + | (phi > 2 * jnp.pi) + | (phi < 0) + | (theta > jnp.pi) + | (theta < 0), + jnp.zeros_like(0) - jnp.inf, + jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), + ) + return output @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8ee9536..55a5d0ce 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod -from typing import Callable +from dataclasses import field +from typing import Callable, Union import jax import jax.numpy as jnp +from beartype import beartype as typechecker from chex import assert_rank -from jaxtyping import Float - +from jaxtyping import Array, Float, jaxtyped class Transform(ABC): """ @@ -16,7 +17,6 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] - def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -45,7 +45,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. 
""" raise NotImplementedError - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -92,8 +92,7 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x - - + class Scale(UnivariateTransform): scale: Float @@ -106,7 +105,6 @@ def __init__( self.scale = scale self.transform_func = lambda x: x * self.scale - class Offset(UnivariateTransform): offset: Float @@ -119,7 +117,6 @@ def __init__( self.offset = offset self.transform_func = lambda x: x + self.offset - class Logit(UnivariateTransform): """ Logit transform following @@ -138,11 +135,10 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) - class ArcSine(UnivariateTransform): """ ArcSine transformation - + Parameters ---------- name_mapping : tuple[list[str], list[str]] @@ -156,22 +152,3 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) - - -class ArcCosine(UnivariateTransform): - """ - ArcCosine transformation - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - self.transform_func = lambda x: jnp.arccos(x) From 7ee41335d7820f21f1b777d0482c39a989a5f65a Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:12:10 -0400 Subject: [PATCH 024/104] Add standard normal distribution --- src/jimgw/prior.py | 96 +++++++++++++++++++--------------------------- 1 file changed, 39 insertions(+), 57 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 894ea9f7..3c411c42 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -63,17 +63,17 @@ def log_prob(self, x: dict[str, Array]) -> Float: class LogisticDistribution(Prior): def __repr__(self): - return f"Logistic(parameter_names={self.parameter_names})" + return f"LogisticDistribution(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) - assert self.n_dim == 1, "Logit needs to be 1D distributions" + assert self.n_dim == 1, "LogisticDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: """ - Sample from a logit distribution. + Sample from a logistic distribution. Parameters ---------- @@ -97,6 +97,42 @@ def log_prob(self, x: dict[str, Float]) -> Float: return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) +@jaxtyped(typechecker=typechecker) +class StandardNormalDistribution(Prior): + + def __repr__(self): + return f"StandardNormalDistribution(parameter_names={self.parameter_names})" + + def __init__(self, parameter_names: list[str], **kwargs): + super().__init__(parameter_names) + assert self.n_dim == 1, "StandardNormalDistribution needs to be 1D distributions" + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a standard normal distribution. + + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. 
+ + """ + samples = jax.random.normal(rng_key, (n_samples,)) + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Float]) -> Float: + variable = x[self.parameter_names[0]] + return -0.5 * variable ** 2 - 0.5 * jnp.log(2 * jnp.pi) + class SequentialTransform(Prior): """ Transform a prior distribution by applying a sequence of transforms. @@ -624,57 +660,3 @@ def log_prob(self, x: dict[str, Float]) -> Float: ) log_p = self.alpha * variable + jnp.log(self.normalization) return log_p + log_in_range - - -@jaxtyped(typechecker=typechecker) -class Normal(Prior): - mean: Float = 0.0 - std: Float = 1.0 - - def __repr__(self): - return f"Normal(mean={self.mean}, std={self.std})" - - def __init__( - self, - mean: Float, - std: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Normal needs to be 1D distributions" - self.mean = mean - self.std = std - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a normal distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - samples = jax.random.normal(rng_key, (n_samples,)) - samples = self.mean + samples * self.std - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = ( - -0.5 * jnp.log(2 * jnp.pi) - - jnp.log(self.std) - - 0.5 * ((variable - self.mean) / self.std) ** 2 - ) - return output From 6a3579296b9cd646d815f49897454af2fc57d7e2 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:18:22 -0400 Subject: [PATCH 025/104] Add periodic uniform prior --- src/jimgw/prior.py | 31 +++++++++++++++++++++++++++++-- src/jimgw/transforms.py | 19 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3c411c42..e2d98f71 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine, Modulo class Prior(Distribution): @@ -249,6 +249,33 @@ def __init__( ], ) +@jaxtyped(typechecker=typechecker) +class PeriodicUniform(SequentialTransform): + xmin: float + xmax: float + + def __repr__(self): + return f"PeriodicUniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" + + def __init__( + self, + xmin: float, + xmax: float, + parameter_names: list[str], + ): + self.parameter_names = parameter_names + assert self.n_dim == 1, "PeriodicUniform needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + super().__init__( + LogisticDistribution(self.parameter_names), + [ + Logit((self.parameter_names, self.parameter_names)), + Modulo((self.parameter_names, self.parameter_names), xmax - xmin), + Offset((self.parameter_names, self.parameter_names), xmin), + ], + ) + @jaxtyped(typechecker=typechecker) class Sine(SequentialTransform): @@ -306,7 +333,7 @@ def __init__(self, parameter_names: list[str], **kwargs): [ Uniform(0.0, 1.0, [self.parameter_names[0]]), 
Sine([self.parameter_names[1]]), - Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + PeriodicUniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), ] ) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8ee9536..05dacdb1 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -138,6 +138,25 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) +class Modulo(UnivariateTransform): + """ + Modulo transform following + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + modulo: Float, + ): + super().__init__(name_mapping) + self.modulo = modulo + self.transform_func = lambda x: jnp.mod(x, self.modulo) class ArcSine(UnivariateTransform): """ From 1bcf32c239c73ef904bda3b1aff86e761dc66bf9 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:20:53 -0400 Subject: [PATCH 026/104] Reformat --- src/jimgw/prior.py | 8 ++++++-- src/jimgw/transforms.py | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e2d98f71..7dd50018 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -105,7 +105,9 @@ def __repr__(self): def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) - assert self.n_dim == 1, "StandardNormalDistribution needs to be 1D distributions" + assert ( + self.n_dim == 1 + ), "StandardNormalDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int @@ -131,7 +133,8 @@ def sample( def log_prob(self, x: dict[str, Float]) -> Float: variable = x[self.parameter_names[0]] - return -0.5 * variable ** 2 - 0.5 * jnp.log(2 * jnp.pi) + return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) + class SequentialTransform(Prior): """ @@ -249,6 +252,7 @@ def __init__( ], ) + @jaxtyped(typechecker=typechecker) class PeriodicUniform(SequentialTransform): xmin: float diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 05dacdb1..8a4787c5 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -138,6 +138,7 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + class Modulo(UnivariateTransform): """ Modulo transform following @@ -158,6 +159,7 @@ def __init__( self.modulo = modulo self.transform_func = lambda x: jnp.mod(x, self.modulo) + class ArcSine(UnivariateTransform): """ ArcSine transformation From 5d98aeb529c165fccde0d14400e6ccb0408fdfd8 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:44:30 -0400 Subject: [PATCH 027/104] Revert "Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-class' into sphere_prior" This reverts commit 17450be4adc3c03d24d23c8ccbb6f9af418980e7, reversing changes made to 1bcf32c239c73ef904bda3b1aff86e761dc66bf9. 
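(Reverting a merge commit requires naming the mainline parent, e.g.
"git revert -m 1 17450be"; the "reversing changes made to ..." line above is
generated by git in that case.)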
--- src/jimgw/prior.py | 72 +++++++++++++---------------------------- src/jimgw/transforms.py | 36 +++++++++++++++++---- 2 files changed, 51 insertions(+), 57 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 74cc39dd..7dd50018 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -49,7 +49,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ return dict(zip(self.parameter_names, x)) - + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -145,9 +145,7 @@ class SequentialTransform(Prior): transforms: list[Transform] def __repr__(self): - return ( - f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" - ) + return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" def __init__( self, @@ -166,24 +164,28 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - + def log_prob(self, x: dict[str, Float]) -> Float: """ log_prob has to be evaluated in the space of the base_prior. - - """ output = self.base_prior.log_prob(x) for transform in self.transforms: x, log_jacobian = transform.transform(x) output -= log_jacobian return output - + + def sample_base( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self.base_prior.sample(rng_key, n_samples) + def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) return x + class Combine(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. @@ -223,16 +225,13 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output - @jaxtyped(typechecker=typechecker) -class Uniform(Prior): - _dist: SequentialTransform - +class Uniform(SequentialTransform): xmin: float xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" def __init__( self, @@ -240,22 +239,19 @@ def __init__( xmax: float, parameter_names: list[str], ): - super().__init__(parameter_names) + self.parameter_names = parameter_names assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - self._dist = SequentialTransform( - LogisticDistribution(parameter_names), + super().__init__( + LogisticDistribution(self.parameter_names), [ - Logit((parameter_names, parameter_names)), - Scale((parameter_names, parameter_names), xmax - xmin), - Offset((parameter_names, parameter_names), xmin), - ]) + Logit((self.parameter_names, self.parameter_names)), + Scale((self.parameter_names, self.parameter_names), xmax - xmin), + Offset((self.parameter_names, self.parameter_names), xmin), + ], + ) - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self._dist.sample(rng_key, n_samples) @jaxtyped(typechecker=typechecker) class PeriodicUniform(SequentialTransform): @@ -325,7 +321,7 @@ def __init__(self, parameter_names: list[str]): class UniformSphere(Combine): def __repr__(self): - return f"Sphere(naming={self.naming})" + return f"UniformSphere(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): assert ( @@ -345,32 +341,8 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) - def sample( - 
self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 3) - theta = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) - mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) - return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - def log_prob(self, x: dict[str, Float]) -> Float: - theta = x[self.naming[0]] - phi = x[self.naming[1]] - mag = x[self.naming[2]] - output = jnp.where( - (mag > 1) - | (mag < 0) - | (phi > 2 * jnp.pi) - | (phi < 0) - | (theta > jnp.pi) - | (theta < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), - ) - return output +# ====================== Things below may need rework ====================== @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index da8c1591..8a4787c5 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -from dataclasses import field -from typing import Callable, Union +from typing import Callable import jax import jax.numpy as jnp -from beartype import beartype as typechecker from chex import assert_rank -from jaxtyping import Array, Float, jaxtyped +from jaxtyping import Float + class Transform(ABC): """ @@ -17,6 +16,7 @@ class Transform(ABC): name_mapping: tuple[list[str], list[str]] transform_func: Callable[[dict[str, Float]], dict[str, Float]] + def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -45,7 +45,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. """ raise NotImplementedError - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -92,7 +92,8 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x - + + class Scale(UnivariateTransform): scale: Float @@ -105,6 +106,7 @@ def __init__( self.scale = scale self.transform_func = lambda x: x * self.scale + class Offset(UnivariateTransform): offset: Float @@ -117,6 +119,7 @@ def __init__( self.offset = offset self.transform_func = lambda x: x + self.offset + class Logit(UnivariateTransform): """ Logit transform following @@ -160,7 +163,7 @@ def __init__( class ArcSine(UnivariateTransform): """ ArcSine transformation - + Parameters ---------- name_mapping : tuple[list[str], list[str]] @@ -174,3 +177,22 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) + + +class ArcCosine(UnivariateTransform): + """ + ArcCosine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
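# [Illustrative aside on the ArcCosine transform above; assumes only jax.]
# d/dx arccos(x) = -1 / sqrt(1 - x**2) is negative everywhere on (-1, 1),
# but it never changes sign, so |T'(x)| stays well defined for the density
# (unlike sin, whose derivative flips sign).
import jax
import jax.numpy as jnp

x = 0.2
print(jax.grad(jnp.arccos)(x), -1.0 / jnp.sqrt(1.0 - x**2))  # both ~ -1.0206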
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: jnp.arccos(x) From 87ee212ac44d5afa09c40312c4d5570b735ae2de Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 07:47:00 -0400 Subject: [PATCH 028/104] Minor text change --- src/jimgw/prior.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 7dd50018..94363f4c 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -195,9 +195,7 @@ class Combine(Prior): priors: list[Prior] = field(default_factory=list) def __repr__(self): - return ( - f"Composite(priors={self.priors}, parameter_names={self.parameter_names})" - ) + return f"Combine(priors={self.priors}, parameter_names={self.parameter_names})" def __init__( self, From cc1448e2ec120828e5bc1b715cd2ae4c6e01563b Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 08:28:48 -0400 Subject: [PATCH 029/104] Remove PeriodicUniform --- src/jimgw/prior.py | 30 +----------------------------- src/jimgw/transforms.py | 21 --------------------- 2 files changed, 1 insertion(+), 50 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 94363f4c..c571c7dc 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine, Modulo +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine class Prior(Distribution): @@ -251,34 +251,6 @@ def __init__( ) -@jaxtyped(typechecker=typechecker) -class PeriodicUniform(SequentialTransform): - xmin: float - xmax: float - - def __repr__(self): - return f"PeriodicUniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" - - def __init__( - self, - xmin: float, - xmax: float, - parameter_names: list[str], - ): - self.parameter_names = parameter_names - assert self.n_dim == 1, "PeriodicUniform needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - super().__init__( - LogisticDistribution(self.parameter_names), - [ - Logit((self.parameter_names, self.parameter_names)), - Modulo((self.parameter_names, self.parameter_names), xmax - xmin), - Offset((self.parameter_names, self.parameter_names), xmin), - ], - ) - - @jaxtyped(typechecker=typechecker) class Sine(SequentialTransform): """ diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 8a4787c5..d8ee9536 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -139,27 +139,6 @@ def __init__( self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) -class Modulo(UnivariateTransform): - """ - Modulo transform following - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. 
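# [Illustrative sketch of the Modulo wrap above; assumes only jax.numpy.]
# jnp.mod folds any real value into [0, modulo), which is what made the
# transform periodic; note the map is many-to-one rather than bijective.
import jax.numpy as jnp

x = jnp.array([-1.5, 0.3, 7.0])
print(jnp.mod(x, 2.0 * jnp.pi))  # every entry lands in [0, 2*pi)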
- - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - modulo: Float, - ): - super().__init__(name_mapping) - self.modulo = modulo - self.transform_func = lambda x: jnp.mod(x, self.modulo) - - class ArcSine(UnivariateTransform): """ ArcSine transformation From 6070f130511caf40fde0a4b00ea87f430d80d653 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 08:41:51 -0400 Subject: [PATCH 030/104] Use self.sample_base --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index c571c7dc..53271680 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -162,7 +162,7 @@ def __init__( def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - output = self.base_prior.sample(rng_key, n_samples) + output = self.sample_base(rng_key, n_samples) return jax.vmap(self.transform)(output) def log_prob(self, x: dict[str, Float]) -> Float: From 8a301f2759b7ddc3bb765d4ca8a00de3dd88994c Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 12:27:25 -0400 Subject: [PATCH 031/104] Update prior.py --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 53271680..de378973 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -307,7 +307,7 @@ def __init__(self, parameter_names: list[str], **kwargs): [ Uniform(0.0, 1.0, [self.parameter_names[0]]), Sine([self.parameter_names[1]]), - PeriodicUniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), ] ) From d761f6bfb83da27d83813f4cb1fce44a2882ebac Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 12:47:01 -0400 Subject: [PATCH 032/104] Update prior.py to include powerLaw --- src/jimgw/prior.py | 191 +++++++++++++++++++++++++++------------------ 1 file changed, 113 insertions(+), 78 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index de378973..637bbc2b 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine, PowerLawTransform, ParetoTransform class Prior(Distribution): @@ -311,6 +311,41 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) +@jaxtyped(typechecker=typechecker) +class PowerLaw(SequentialTransform): + xmin: float + xmax: float + alpha: float + + def __repr__(self): + return f"PowerLaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" + + def __init__( + self, + xmin: float, + xmax: float, + alpha: float, + parameter_names: list[str], + ): + self.parameter_names = parameter_names + assert self.n_dim == 1, "Power law needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + if self.alpha == -1.0: + transform = ParetoTransform((self.parameter_names, self.parameter_names), xmin, xmax) + else: + transform = PowerLawTransform( + (self.parameter_names, self.parameter_names), xmin, xmax, alpha + ) + super().__init__( + LogisticDistribution(self.parameter_names), + [ + Logit((self.parameter_names, self.parameter_names)), + transform, + ], + ) + # ====================== Things below may need rework 
====================== @@ -504,83 +539,83 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output + jnp.log(jnp.sin(zenith)) -@jaxtyped(typechecker=typechecker) -class PowerLaw(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ x^{\alpha} - """ - - xmin: float = 0.0 - xmax: float = 1.0 - alpha: float = 0.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: float, - xmax: float, - alpha: Union[Int, float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin > 0.0, "With negative alpha, xmin must > 0" - assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - if alpha == -1: - self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) - else: - self.normalization = (1 + self.alpha) / ( - self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a power-law distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - if self.alpha == -1: - samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) - else: - samples = ( - self.xmin ** (1.0 + self.alpha) - + q_samples - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) - return log_p + log_in_range +# @jaxtyped(typechecker=typechecker) +# class PowerLaw(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ x^{\alpha} +# """ + +# xmin: float = 0.0 +# xmax: float = 1.0 +# alpha: float = 0.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: float, +# xmax: float, +# alpha: Union[Int, float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin > 0.0, "With negative alpha, xmin must > 0" +# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha +# if alpha == -1: +# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) +# else: +# self.normalization = (1 + self.alpha) / ( +# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a power-law distribution. 
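# [Illustrative derivation behind the PowerLawTransform/ParetoTransform split
#  above; assumes only jax, not jimgw.]
# Inverse-CDF sampling for p(x) ~ x**alpha on [xmin, xmax], u ~ Uniform(0, 1):
#   alpha != -1:  x = (xmin**(1+alpha) + u * (xmax**(1+alpha) - xmin**(1+alpha)))**(1/(1+alpha))
#   alpha == -1:  x = xmin * exp(u * log(xmax / xmin))
import jax
import jax.numpy as jnp

u = jax.random.uniform(jax.random.PRNGKey(2), (5,))
xmin, xmax, alpha = 1.0, 100.0, -2.0
x = (xmin ** (1 + alpha) + u * (xmax ** (1 + alpha) - xmin ** (1 + alpha))) ** (1 / (1 + alpha))
print(x)  # draws concentrate toward xmin for alpha < -1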
+ +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# if self.alpha == -1: +# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) +# else: +# samples = ( +# self.xmin ** (1.0 + self.alpha) +# + q_samples +# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) +# return log_p + log_in_range @jaxtyped(typechecker=typechecker) From b9a725506012f80db3d297fe66f8e73fc632d37e Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 12:47:51 -0400 Subject: [PATCH 033/104] Update transforms.py --- src/jimgw/transforms.py | 55 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8ee9536..2feac5a2 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -175,3 +175,58 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arccos(x) + + +class PowerLawTransform(UnivariateTransform): + """ + PowerLaw transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + xmin: Float + xmax: Float + alpha: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + alpha: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.alpha = alpha + self.transform_func = lambda x: (self.xmin ** (1.0 + self.alpha)+ x* (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)))** (1.0 / (1.0 + self.alpha)), + + + +class ParetoTransform(UnivariateTransform): + """ + Pareto transformation: Power law when alpha = -1 + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
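# [Aside on PowerLawTransform.__init__ above, an illustrative sketch: the
# trailing comma after the lambda binds transform_func to a one-element tuple
# rather than to the callable itself. Unpacked for readability, the assignment
# presumably meant is:]
#
#     self.transform_func = lambda x: (
#         self.xmin ** (1.0 + self.alpha)
#         + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
#     ) ** (1.0 / (1.0 + self.alpha))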
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.transform_func = lambda x: self.xmin * jnp.exp( + x * jnp.log(self.xmax / self.xmin) + ) From 2f6e12ac73915ccf0e2081ef241f0dc3cf9eb53e Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 16:15:41 -0400 Subject: [PATCH 034/104] format and updating typing hint --- src/jimgw/prior.py | 945 ++++++++++++++++++++-------------------- src/jimgw/transforms.py | 20 +- 2 files changed, 483 insertions(+), 482 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 1e7cb960..77d5713b 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -1,15 +1,11 @@ from dataclasses import field -from typing import Callable, Union import jax import jax.numpy as jnp -from astropy.time import Time from beartype import beartype as typechecker from flowMC.nfmodel.base import Distribution -from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped +from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped -from jimgw.single_event.detector import GroundBased2G, detector_preset -from jimgw.single_event.utils import zenith_azimuth_to_ra_dec from jimgw.transforms import Transform, Logit, Scale, Offset @@ -49,7 +45,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ return dict(zip(self.parameter_names, x)) - + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: @@ -106,9 +102,7 @@ class SequentialTransform(Prior): transforms: list[Transform] def __repr__(self): - return ( - f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" - ) + return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" def __init__( self, @@ -127,7 +121,7 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - + def log_prob(self, x: dict[str, Float]) -> Float: """ log_prob has to be evaluated in the space of the base_prior. @@ -139,12 +133,13 @@ def log_prob(self, x: dict[str, Float]) -> Float: x, log_jacobian = transform.transform(x) output -= log_jacobian return output - + def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) return x + class Combine(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. @@ -184,7 +179,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output - @jaxtyped(typechecker=typechecker) class Uniform(Prior): _dist: SequentialTransform @@ -211,7 +205,8 @@ def __init__( Logit((parameter_names, parameter_names)), Scale((parameter_names, parameter_names), xmax - xmin), Offset((parameter_names, parameter_names), xmin), - ]) + ], + ) def sample( self, rng_key: PRNGKeyArray, n_samples: int @@ -220,474 +215,476 @@ def sample( def log_prob(self, x: dict[str, Array]) -> Float: return self._dist.log_prob(x) - + def sample_base( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: return self._dist.base_prior.sample(rng_key, n_samples) - + def transform(self, x: dict[str, Float]) -> dict[str, Float]: return self._dist.transform(x) -# ====================== Things below may need rework ====================== - -class Sphere(Prior): - """ - A prior on a sphere represented by Cartesian coordinates. 
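# [Illustrative check of the ParetoTransform map above; assumes only jax.]
# x -> xmin * exp(x * log(xmax / xmin)) sends Uniform(0, 1) to a log-uniform
# (p(x) ~ 1/x) distribution on [xmin, xmax], i.e. the alpha = -1 power law.
import jax
import jax.numpy as jnp

u = jax.random.uniform(jax.random.PRNGKey(3), (100_000,))
x = 1.0 * jnp.exp(u * jnp.log(100.0 / 1.0))
print(jnp.log(x).min(), jnp.log(x).max())  # log(x) is uniform on [0, log(100)]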
- - Magnitude is sampled from a uniform distribution. - """ - - def __repr__(self): - return f"Sphere(naming={self.naming})" - - def __init__(self, naming: list[str], **kwargs): - name = naming[0] - self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] - self.transforms = { - self.naming[0]: ( - f"{naming}_x", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.cos(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[1]: ( - f"{naming}_y", - lambda params: jnp.sin(params[self.naming[0]]) - * jnp.sin(params[self.naming[1]]) - * params[self.naming[2]], - ), - self.naming[2]: ( - f"{naming}_z", - lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 3) - theta = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) - mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) - return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - theta = x[self.naming[0]] - phi = x[self.naming[1]] - mag = x[self.naming[2]] - output = jnp.where( - (mag > 1) - | (mag < 0) - | (phi > 2 * jnp.pi) - | (phi < 0) - | (theta > jnp.pi) - | (theta < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), - ) - return output - - -@jaxtyped(typechecker=typechecker) -class AlignedSpin(Prior): - """ - Prior distribution for the aligned (z) component of the spin. - - This assume the prior distribution on the spin magnitude to be uniform in [0, amax] - with its orientation uniform on a sphere - - p(chi) = -log(|chi| / amax) / 2 / amax - - This is useful when comparing results between an aligned-spin run and - a precessing spin run. - - See (A7) of https://arxiv.org/abs/1805.10457. - """ - - amax: Float = 0.99 - chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - - def __repr__(self): - return f"Alignedspin(amax={self.amax}, naming={self.naming})" - - def __init__( - self, - amax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" - self.amax = amax - - # build the interpolation table for the ppf of the one-sided distribution - chi_axis = jnp.linspace(1e-31, self.amax, num=1000) - cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax - self.chi_axis = chi_axis - self.cdf_vals = cdf_vals - - @property - def xmin(self): - return -self.amax - - @property - def xmax(self): - return self.amax - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from the Alignedspin distribution. - for chi > 0; - p(chi) = -log(chi / amax) / amax # halved normalization constant - cdf(chi) = -chi * (log(chi / amax) - 1) / amax - - Since there is a pole at chi=0, we will sample with the following steps - 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise - 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) - 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) - 3. 
Map the quantile to chi via the ppf by checking against the table - built during the initialization - 4. add back the sign - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - # 1. calculate the sign of chi from the q_samples - sign_samples = jnp.where( - q_samples >= 0.5, - jnp.zeros_like(q_samples) + 1.0, - jnp.zeros_like(q_samples) - 1.0, - ) - # 2. remap q_samples - q_samples = jnp.where( - q_samples >= 0.5, - 2 * (q_samples - 0.5), - 2 * (0.5 - q_samples), - ) - # 3. map the quantile to chi via interpolation - samples = jnp.interp( - q_samples, - self.cdf_vals, - self.chi_axis, - ) - # 4. add back the sign - samples *= sign_samples - - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_p = jnp.where( - (variable >= self.amax) | (variable <= -self.amax), - jnp.zeros_like(variable) - jnp.inf, - jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), - ) - return log_p - - -@jaxtyped(typechecker=typechecker) -class EarthFrame(Prior): - """ - Prior distribution for sky location in Earth frame. - """ - - ifos: list = field(default_factory=list) - gmst: float = 0.0 - delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) - - def __repr__(self): - return f"EarthFrame(naming={self.naming})" - - def __init__(self, gps: Float, ifos: list, **kwargs): - self.naming = ["zenith", "azimuth"] - if len(ifos) < 2: - return ValueError( - "At least two detectors are needed to define the Earth frame" - ) - elif isinstance(ifos[0], str): - self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] - elif isinstance(ifos[0], GroundBased2G): - self.ifos = ifos[:1] - else: - return ValueError( - "ifos should be a list of detector names or GroundBased2G objects" - ) - self.gmst = float( - Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad - ) - self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex - - self.transforms = { - "azimuth": ( - "ra", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[0], - ), - "zenith": ( - "dec", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[1], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 2) - zenith = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - azimuth = jax.random.uniform( - rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi - ) - return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - zenith = x["zenith"] - azimuth = x["azimuth"] - output = jnp.where( - (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.zeros_like(0), - ) - return output + jnp.log(jnp.sin(zenith)) - - -@jaxtyped(typechecker=typechecker) -class PowerLaw(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). 
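# [Illustrative sketch of the tabulated inverse-CDF idea used above; assumes
#  only jax.] When a CDF has no closed-form inverse, build a table once and
# invert it with jnp.interp, mapping uniform quantiles to samples.
import jax
import jax.numpy as jnp

grid = jnp.linspace(1e-6, 1.0, 1000)   # support of the variable
cdf = grid**2                          # example CDF, for p(x) ~ x on [0, 1]
q = jax.random.uniform(jax.random.PRNGKey(4), (5,))
print(jnp.interp(q, cdf, grid))        # ppf(q) = sqrt(q), read off the table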
- p(x) ~ x^{\alpha} - """ - - xmin: float = 0.0 - xmax: float = 1.0 - alpha: float = 0.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: float, - xmax: float, - alpha: Union[Int, float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin > 0.0, "With negative alpha, xmin must > 0" - assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - if alpha == -1: - self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) - else: - self.normalization = (1 + self.alpha) / ( - self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a power-law distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - if self.alpha == -1: - samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) - else: - samples = ( - self.xmin ** (1.0 + self.alpha) - + q_samples - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) - return log_p + log_in_range - - -@jaxtyped(typechecker=typechecker) -class Exponential(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ exp(\alpha x) - """ - - xmin: float = 0.0 - xmax: float = jnp.inf - alpha: float = -1.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: Float, - xmax: Float, - alpha: Union[Int, Float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin != -jnp.inf, "With negative alpha, xmin must finite" - if alpha > 0.0: - assert xmax != jnp.inf, "With positive alpha, xmax must finite" - assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" - assert self.n_dim == 1, "Exponential needs to be 1D distributions" - - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - - self.normalization = self.alpha / ( - jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a exponential distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. 
- - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - samples = ( - self.xmin - + jnp.log1p( - q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) - ) - / self.alpha - ) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * variable + jnp.log(self.normalization) - return log_p + log_in_range - - -@jaxtyped(typechecker=typechecker) -class Normal(Prior): - mean: Float = 0.0 - std: Float = 1.0 - - def __repr__(self): - return f"Normal(mean={self.mean}, std={self.std})" - - def __init__( - self, - mean: Float, - std: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Normal needs to be 1D distributions" - self.mean = mean - self.std = std - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a normal distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. +# ====================== Things below may need rework ====================== - """ - samples = jax.random.normal(rng_key, (n_samples,)) - samples = self.mean + samples * self.std - return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Array]) -> Float: - variable = x[self.naming[0]] - output = ( - -0.5 * jnp.log(2 * jnp.pi) - - jnp.log(self.std) - - 0.5 * ((variable - self.mean) / self.std) ** 2 - ) - return output +# class Sphere(Prior): +# """ +# A prior on a sphere represented by Cartesian coordinates. + +# Magnitude is sampled from a uniform distribution. 
+# """ + +# def __repr__(self): +# return f"Sphere(naming={self.naming})" + +# def __init__(self, naming: list[str], **kwargs): +# name = naming[0] +# self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"] +# self.transforms = { +# self.naming[0]: ( +# f"{naming}_x", +# lambda params: jnp.sin(params[self.naming[0]]) +# * jnp.cos(params[self.naming[1]]) +# * params[self.naming[2]], +# ), +# self.naming[1]: ( +# f"{naming}_y", +# lambda params: jnp.sin(params[self.naming[0]]) +# * jnp.sin(params[self.naming[1]]) +# * params[self.naming[2]], +# ), +# self.naming[2]: ( +# f"{naming}_z", +# lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], +# ), +# } + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# rng_keys = jax.random.split(rng_key, 3) +# theta = jnp.arccos( +# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) +# ) +# phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) +# mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) +# return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# theta = x[self.naming[0]] +# phi = x[self.naming[1]] +# mag = x[self.naming[2]] +# output = jnp.where( +# (mag > 1) +# | (mag < 0) +# | (phi > 2 * jnp.pi) +# | (phi < 0) +# | (theta > jnp.pi) +# | (theta < 0), +# jnp.zeros_like(0) - jnp.inf, +# jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), +# ) +# return output + + +# @jaxtyped(typechecker=typechecker) +# class AlignedSpin(Prior): +# """ +# Prior distribution for the aligned (z) component of the spin. + +# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] +# with its orientation uniform on a sphere + +# p(chi) = -log(|chi| / amax) / 2 / amax + +# This is useful when comparing results between an aligned-spin run and +# a precessing spin run. + +# See (A7) of https://arxiv.org/abs/1805.10457. +# """ + +# amax: Float = 0.99 +# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) +# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + +# def __repr__(self): +# return f"Alignedspin(amax={self.amax}, naming={self.naming})" + +# def __init__( +# self, +# amax: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" +# self.amax = amax + +# # build the interpolation table for the ppf of the one-sided distribution +# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) +# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax +# self.chi_axis = chi_axis +# self.cdf_vals = cdf_vals + +# @property +# def xmin(self): +# return -self.amax + +# @property +# def xmax(self): +# return self.amax + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from the Alignedspin distribution. + +# for chi > 0; +# p(chi) = -log(chi / amax) / amax # halved normalization constant +# cdf(chi) = -chi * (log(chi / amax) - 1) / amax + +# Since there is a pole at chi=0, we will sample with the following steps +# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise +# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) +# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) +# 3. 
Map the quantile to chi via the ppf by checking against the table +# built during the initialization +# 4. add back the sign + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# # 1. calculate the sign of chi from the q_samples +# sign_samples = jnp.where( +# q_samples >= 0.5, +# jnp.zeros_like(q_samples) + 1.0, +# jnp.zeros_like(q_samples) - 1.0, +# ) +# # 2. remap q_samples +# q_samples = jnp.where( +# q_samples >= 0.5, +# 2 * (q_samples - 0.5), +# 2 * (0.5 - q_samples), +# ) +# # 3. map the quantile to chi via interpolation +# samples = jnp.interp( +# q_samples, +# self.cdf_vals, +# self.chi_axis, +# ) +# # 4. add back the sign +# samples *= sign_samples + +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_p = jnp.where( +# (variable >= self.amax) | (variable <= -self.amax), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), +# ) +# return log_p + + +# @jaxtyped(typechecker=typechecker) +# class EarthFrame(Prior): +# """ +# Prior distribution for sky location in Earth frame. +# """ + +# ifos: list = field(default_factory=list) +# gmst: float = 0.0 +# delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) + +# def __repr__(self): +# return f"EarthFrame(naming={self.naming})" + +# def __init__(self, gps: Float, ifos: list, **kwargs): +# self.naming = ["zenith", "azimuth"] +# if len(ifos) < 2: +# return ValueError( +# "At least two detectors are needed to define the Earth frame" +# ) +# elif isinstance(ifos[0], str): +# self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] +# elif isinstance(ifos[0], GroundBased2G): +# self.ifos = ifos[:1] +# else: +# return ValueError( +# "ifos should be a list of detector names or GroundBased2G objects" +# ) +# self.gmst = float( +# Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad +# ) +# self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex + +# self.transforms = { +# "azimuth": ( +# "ra", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[0], +# ), +# "zenith": ( +# "dec", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[1], +# ), +# } + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# rng_keys = jax.random.split(rng_key, 2) +# zenith = jnp.arccos( +# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) +# ) +# azimuth = jax.random.uniform( +# rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi +# ) +# return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# zenith = x["zenith"] +# azimuth = x["azimuth"] +# output = jnp.where( +# (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), +# jnp.zeros_like(0) - jnp.inf, +# jnp.zeros_like(0), +# ) +# return output + jnp.log(jnp.sin(zenith)) + + +# @jaxtyped(typechecker=typechecker) +# class PowerLaw(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). 
+# p(x) ~ x^{\alpha} +# """ + +# xmin: float = 0.0 +# xmax: float = 1.0 +# alpha: float = 0.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: float, +# xmax: float, +# alpha: Union[Int, float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin > 0.0, "With negative alpha, xmin must > 0" +# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha +# if alpha == -1: +# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) +# else: +# self.normalization = (1 + self.alpha) / ( +# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a power-law distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# if self.alpha == -1: +# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) +# else: +# samples = ( +# self.xmin ** (1.0 + self.alpha) +# + q_samples +# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) +# return log_p + log_in_range + + +# @jaxtyped(typechecker=typechecker) +# class Exponential(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ exp(\alpha x) +# """ + +# xmin: float = 0.0 +# xmax: float = jnp.inf +# alpha: float = -1.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: Float, +# xmax: Float, +# alpha: Union[Int, Float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin != -jnp.inf, "With negative alpha, xmin must finite" +# if alpha > 0.0: +# assert xmax != jnp.inf, "With positive alpha, xmax must finite" +# assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" +# assert self.n_dim == 1, "Exponential needs to be 1D distributions" + +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha + +# self.normalization = self.alpha / ( +# jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a exponential distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. 
+ +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# samples = ( +# self.xmin +# + jnp.log1p( +# q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) +# ) +# / self.alpha +# ) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * variable + jnp.log(self.normalization) +# return log_p + log_in_range + + +# @jaxtyped(typechecker=typechecker) +# class Normal(Prior): +# mean: Float = 0.0 +# std: Float = 1.0 + +# def __repr__(self): +# return f"Normal(mean={self.mean}, std={self.std})" + +# def __init__( +# self, +# mean: Float, +# std: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Normal needs to be 1D distributions" +# self.mean = mean +# self.std = std + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a normal distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# samples = jax.random.normal(rng_key, (n_samples,)) +# samples = self.mean + samples * self.std +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Array]) -> Float: +# variable = x[self.naming[0]] +# output = ( +# -0.5 * jnp.log(2 * jnp.pi) +# - jnp.log(self.std) +# - 0.5 * ((variable - self.mean) / self.std) ** 2 +# ) +# return output diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 55a5d0ce..9e06d6ab 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -from dataclasses import field -from typing import Callable, Union +from typing import Callable import jax import jax.numpy as jnp -from beartype import beartype as typechecker from chex import assert_rank -from jaxtyping import Array, Float, jaxtyped +from jaxtyping import Float, Array + class Transform(ABC): """ @@ -16,7 +15,8 @@ class Transform(ABC): """ name_mapping: tuple[list[str], list[str]] - transform_func: Callable[[dict[str, Float]], dict[str, Float]] + transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + def __init__( self, name_mapping: tuple[list[str], list[str]], @@ -45,7 +45,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. 
""" raise NotImplementedError - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -92,7 +92,8 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: output_params = self.transform_func(input_params) x[self.name_mapping[1][0]] = output_params return x - + + class Scale(UnivariateTransform): scale: Float @@ -105,6 +106,7 @@ def __init__( self.scale = scale self.transform_func = lambda x: x * self.scale + class Offset(UnivariateTransform): offset: Float @@ -117,6 +119,7 @@ def __init__( self.offset = offset self.transform_func = lambda x: x + self.offset + class Logit(UnivariateTransform): """ Logit transform following @@ -135,10 +138,11 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + class ArcSine(UnivariateTransform): """ ArcSine transformation - + Parameters ---------- name_mapping : tuple[list[str], list[str]] From 194c565722beb10fd4a5c41934434583e7de9908 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 12:52:28 -0400 Subject: [PATCH 035/104] Revert "Update transforms.py" This reverts commit b9a725506012f80db3d297fe66f8e73fc632d37e. --- src/jimgw/transforms.py | 55 ----------------------------------------- 1 file changed, 55 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 2feac5a2..d8ee9536 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -175,58 +175,3 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arccos(x) - - -class PowerLawTransform(UnivariateTransform): - """ - PowerLaw transformation - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - xmin: Float - xmax: Float - alpha: Float - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - alpha: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.alpha = alpha - self.transform_func = lambda x: (self.xmin ** (1.0 + self.alpha)+ x* (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)))** (1.0 / (1.0 + self.alpha)), - - - -class ParetoTransform(UnivariateTransform): - """ - Pareto transformation: Power law when alpha = -1 - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.transform_func = lambda x: self.xmin * jnp.exp( - x * jnp.log(self.xmax / self.xmin) - ) From 5807f3c110564e5215344a5e00755027b8535b59 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 12:52:35 -0400 Subject: [PATCH 036/104] Revert "Update prior.py to include powerLaw" This reverts commit d761f6bfb83da27d83813f4cb1fce44a2882ebac. 
--- src/jimgw/prior.py | 191 ++++++++++++++++++--------------------------- 1 file changed, 78 insertions(+), 113 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 637bbc2b..de378973 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -10,7 +10,7 @@ from jimgw.single_event.detector import GroundBased2G, detector_preset from jimgw.single_event.utils import zenith_azimuth_to_ra_dec -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine, PowerLawTransform, ParetoTransform +from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine class Prior(Distribution): @@ -311,41 +311,6 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) -@jaxtyped(typechecker=typechecker) -class PowerLaw(SequentialTransform): - xmin: float - xmax: float - alpha: float - - def __repr__(self): - return f"PowerLaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" - - def __init__( - self, - xmin: float, - xmax: float, - alpha: float, - parameter_names: list[str], - ): - self.parameter_names = parameter_names - assert self.n_dim == 1, "Power law needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - if self.alpha == -1.0: - transform = ParetoTransform((self.parameter_names, self.parameter_names), xmin, xmax) - else: - transform = PowerLawTransform( - (self.parameter_names, self.parameter_names), xmin, xmax, alpha - ) - super().__init__( - LogisticDistribution(self.parameter_names), - [ - Logit((self.parameter_names, self.parameter_names)), - transform, - ], - ) - # ====================== Things below may need rework ====================== @@ -539,83 +504,83 @@ def log_prob(self, x: dict[str, Float]) -> Float: return output + jnp.log(jnp.sin(zenith)) -# @jaxtyped(typechecker=typechecker) -# class PowerLaw(Prior): -# """ -# A prior following the power-law with alpha in the range [xmin, xmax). -# p(x) ~ x^{\alpha} -# """ - -# xmin: float = 0.0 -# xmax: float = 1.0 -# alpha: float = 0.0 -# normalization: float = 1.0 - -# def __repr__(self): -# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - -# def __init__( -# self, -# xmin: float, -# xmax: float, -# alpha: Union[Int, float], -# naming: list[str], -# transforms: dict[str, tuple[str, Callable]] = {}, -# **kwargs, -# ): -# super().__init__(naming, transforms) -# if alpha < 0.0: -# assert xmin > 0.0, "With negative alpha, xmin must > 0" -# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" -# self.xmax = xmax -# self.xmin = xmin -# self.alpha = alpha -# if alpha == -1: -# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) -# else: -# self.normalization = (1 + self.alpha) / ( -# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) -# ) - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# """ -# Sample from a power-law distribution. - -# Parameters -# ---------- -# rng_key : PRNGKeyArray -# A random key to use for sampling. -# n_samples : int -# The number of samples to draw. - -# Returns -# ------- -# samples : dict -# Samples from the distribution. The keys are the names of the parameters. 
- -# """ -# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) -# if self.alpha == -1: -# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) -# else: -# samples = ( -# self.xmin ** (1.0 + self.alpha) -# + q_samples -# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) -# ) ** (1.0 / (1.0 + self.alpha)) -# return self.add_name(samples[None]) - -# def log_prob(self, x: dict[str, Float]) -> Float: -# variable = x[self.naming[0]] -# log_in_range = jnp.where( -# (variable >= self.xmax) | (variable <= self.xmin), -# jnp.zeros_like(variable) - jnp.inf, -# jnp.zeros_like(variable), -# ) -# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) -# return log_p + log_in_range +@jaxtyped(typechecker=typechecker) +class PowerLaw(Prior): + """ + A prior following the power-law with alpha in the range [xmin, xmax). + p(x) ~ x^{\alpha} + """ + + xmin: float = 0.0 + xmax: float = 1.0 + alpha: float = 0.0 + normalization: float = 1.0 + + def __repr__(self): + return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + + def __init__( + self, + xmin: float, + xmax: float, + alpha: Union[Int, float], + naming: list[str], + transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, + ): + super().__init__(naming, transforms) + if alpha < 0.0: + assert xmin > 0.0, "With negative alpha, xmin must > 0" + assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + if alpha == -1: + self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) + else: + self.normalization = (1 + self.alpha) / ( + self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) + ) + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a power-law distribution. + + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. 
+ + """ + q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) + if self.alpha == -1: + samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) + else: + samples = ( + self.xmin ** (1.0 + self.alpha) + + q_samples + * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) ** (1.0 / (1.0 + self.alpha)) + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Float]) -> Float: + variable = x[self.naming[0]] + log_in_range = jnp.where( + (variable >= self.xmax) | (variable <= self.xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), + ) + log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) + return log_p + log_in_range @jaxtyped(typechecker=typechecker) From 8256d0ee94c4d787d0257280d01461ce5127a8cd Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 13:00:09 -0400 Subject: [PATCH 037/104] Comment out old prior --- src/jimgw/prior.py | 692 ++++++++++++++++++++++----------------------- 1 file changed, 346 insertions(+), 346 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index dc983a50..499bd1c2 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -313,349 +313,349 @@ def __init__(self, parameter_names: list[str], **kwargs): # ====================== Things below may need rework ====================== -@jaxtyped(typechecker=typechecker) -class AlignedSpin(Prior): - """ - Prior distribution for the aligned (z) component of the spin. - - This assume the prior distribution on the spin magnitude to be uniform in [0, amax] - with its orientation uniform on a sphere - - p(chi) = -log(|chi| / amax) / 2 / amax - - This is useful when comparing results between an aligned-spin run and - a precessing spin run. - - See (A7) of https://arxiv.org/abs/1805.10457. - """ - - amax: Float = 0.99 - chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) - - def __repr__(self): - return f"Alignedspin(amax={self.amax}, naming={self.naming})" - - def __init__( - self, - amax: Float, - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" - self.amax = amax - - # build the interpolation table for the ppf of the one-sided distribution - chi_axis = jnp.linspace(1e-31, self.amax, num=1000) - cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax - self.chi_axis = chi_axis - self.cdf_vals = cdf_vals - - @property - def xmin(self): - return -self.amax - - @property - def xmax(self): - return self.amax - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from the Alignedspin distribution. - - for chi > 0; - p(chi) = -log(chi / amax) / amax # halved normalization constant - cdf(chi) = -chi * (log(chi / amax) - 1) / amax - - Since there is a pole at chi=0, we will sample with the following steps - 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise - 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) - 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) - 3. Map the quantile to chi via the ppf by checking against the table - built during the initialization - 4. add back the sign - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. 
- n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - # 1. calculate the sign of chi from the q_samples - sign_samples = jnp.where( - q_samples >= 0.5, - jnp.zeros_like(q_samples) + 1.0, - jnp.zeros_like(q_samples) - 1.0, - ) - # 2. remap q_samples - q_samples = jnp.where( - q_samples >= 0.5, - 2 * (q_samples - 0.5), - 2 * (0.5 - q_samples), - ) - # 3. map the quantile to chi via interpolation - samples = jnp.interp( - q_samples, - self.cdf_vals, - self.chi_axis, - ) - # 4. add back the sign - samples *= sign_samples - - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_p = jnp.where( - (variable >= self.amax) | (variable <= -self.amax), - jnp.zeros_like(variable) - jnp.inf, - jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), - ) - return log_p - - -@jaxtyped(typechecker=typechecker) -class EarthFrame(Prior): - """ - Prior distribution for sky location in Earth frame. - """ - - ifos: list = field(default_factory=list) - gmst: float = 0.0 - delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) - - def __repr__(self): - return f"EarthFrame(naming={self.naming})" - - def __init__(self, gps: Float, ifos: list, **kwargs): - self.naming = ["zenith", "azimuth"] - if len(ifos) < 2: - return ValueError( - "At least two detectors are needed to define the Earth frame" - ) - elif isinstance(ifos[0], str): - self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] - elif isinstance(ifos[0], GroundBased2G): - self.ifos = ifos[:1] - else: - return ValueError( - "ifos should be a list of detector names or GroundBased2G objects" - ) - self.gmst = float( - Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad - ) - self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex - - self.transforms = { - "azimuth": ( - "ra", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[0], - ), - "zenith": ( - "dec", - lambda params: zenith_azimuth_to_ra_dec( - params["zenith"], - params["azimuth"], - gmst=self.gmst, - delta_x=self.delta_x, - )[1], - ), - } - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - rng_keys = jax.random.split(rng_key, 2) - zenith = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) - azimuth = jax.random.uniform( - rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi - ) - return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) - - def log_prob(self, x: dict[str, Float]) -> Float: - zenith = x["zenith"] - azimuth = x["azimuth"] - output = jnp.where( - (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), - jnp.zeros_like(0) - jnp.inf, - jnp.zeros_like(0), - ) - return output + jnp.log(jnp.sin(zenith)) - - -@jaxtyped(typechecker=typechecker) -class PowerLaw(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). 
- p(x) ~ x^{\alpha} - """ - - xmin: float = 0.0 - xmax: float = 1.0 - alpha: float = 0.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: float, - xmax: float, - alpha: Union[Int, float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin > 0.0, "With negative alpha, xmin must > 0" - assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - if alpha == -1: - self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) - else: - self.normalization = (1 + self.alpha) / ( - self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a power-law distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. - - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - if self.alpha == -1: - samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) - else: - samples = ( - self.xmin ** (1.0 + self.alpha) - + q_samples - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) - return log_p + log_in_range - - -@jaxtyped(typechecker=typechecker) -class Exponential(Prior): - """ - A prior following the power-law with alpha in the range [xmin, xmax). - p(x) ~ exp(\alpha x) - """ - - xmin: float = 0.0 - xmax: float = jnp.inf - alpha: float = -1.0 - normalization: float = 1.0 - - def __repr__(self): - return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - - def __init__( - self, - xmin: Float, - xmax: Float, - alpha: Union[Int, Float], - naming: list[str], - transforms: dict[str, tuple[str, Callable]] = {}, - **kwargs, - ): - super().__init__(naming, transforms) - if alpha < 0.0: - assert xmin != -jnp.inf, "With negative alpha, xmin must finite" - if alpha > 0.0: - assert xmax != jnp.inf, "With positive alpha, xmax must finite" - assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" - assert self.n_dim == 1, "Exponential needs to be 1D distributions" - - self.xmax = xmax - self.xmin = xmin - self.alpha = alpha - - self.normalization = self.alpha / ( - jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) - ) - - def sample( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - """ - Sample from a exponential distribution. - - Parameters - ---------- - rng_key : PRNGKeyArray - A random key to use for sampling. - n_samples : int - The number of samples to draw. - - Returns - ------- - samples : dict - Samples from the distribution. The keys are the names of the parameters. 
- - """ - q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) - samples = ( - self.xmin - + jnp.log1p( - q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) - ) - / self.alpha - ) - return self.add_name(samples[None]) - - def log_prob(self, x: dict[str, Float]) -> Float: - variable = x[self.naming[0]] - log_in_range = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), - ) - log_p = self.alpha * variable + jnp.log(self.normalization) - return log_p + log_in_range +# @jaxtyped(typechecker=typechecker) +# class AlignedSpin(Prior): +# """ +# Prior distribution for the aligned (z) component of the spin. + +# This assume the prior distribution on the spin magnitude to be uniform in [0, amax] +# with its orientation uniform on a sphere + +# p(chi) = -log(|chi| / amax) / 2 / amax + +# This is useful when comparing results between an aligned-spin run and +# a precessing spin run. + +# See (A7) of https://arxiv.org/abs/1805.10457. +# """ + +# amax: Float = 0.99 +# chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) +# cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + +# def __repr__(self): +# return f"Alignedspin(amax={self.amax}, naming={self.naming})" + +# def __init__( +# self, +# amax: Float, +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" +# self.amax = amax + +# # build the interpolation table for the ppf of the one-sided distribution +# chi_axis = jnp.linspace(1e-31, self.amax, num=1000) +# cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax +# self.chi_axis = chi_axis +# self.cdf_vals = cdf_vals + +# @property +# def xmin(self): +# return -self.amax + +# @property +# def xmax(self): +# return self.amax + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from the Alignedspin distribution. + +# for chi > 0; +# p(chi) = -log(chi / amax) / amax # halved normalization constant +# cdf(chi) = -chi * (log(chi / amax) - 1) / amax + +# Since there is a pole at chi=0, we will sample with the following steps +# 1. Map the samples with quantile > 0.5 to positive chi and negative otherwise +# 2a. For negative chi, map the quantile back to [0, 1] via q -> 2(0.5 - q) +# 2b. For positive chi, map the quantile back to [0, 1] via q -> 2(q - 0.5) +# 3. Map the quantile to chi via the ppf by checking against the table +# built during the initialization +# 4. add back the sign + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# # 1. calculate the sign of chi from the q_samples +# sign_samples = jnp.where( +# q_samples >= 0.5, +# jnp.zeros_like(q_samples) + 1.0, +# jnp.zeros_like(q_samples) - 1.0, +# ) +# # 2. remap q_samples +# q_samples = jnp.where( +# q_samples >= 0.5, +# 2 * (q_samples - 0.5), +# 2 * (0.5 - q_samples), +# ) +# # 3. map the quantile to chi via interpolation +# samples = jnp.interp( +# q_samples, +# self.cdf_vals, +# self.chi_axis, +# ) +# # 4. 
add back the sign +# samples *= sign_samples + +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_p = jnp.where( +# (variable >= self.amax) | (variable <= -self.amax), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), +# ) +# return log_p + + +# @jaxtyped(typechecker=typechecker) +# class EarthFrame(Prior): +# """ +# Prior distribution for sky location in Earth frame. +# """ + +# ifos: list = field(default_factory=list) +# gmst: float = 0.0 +# delta_x: Float[Array, " 3"] = field(default_factory=lambda: jnp.zeros(3)) + +# def __repr__(self): +# return f"EarthFrame(naming={self.naming})" + +# def __init__(self, gps: Float, ifos: list, **kwargs): +# self.naming = ["zenith", "azimuth"] +# if len(ifos) < 2: +# return ValueError( +# "At least two detectors are needed to define the Earth frame" +# ) +# elif isinstance(ifos[0], str): +# self.ifos = [detector_preset[ifos[0]], detector_preset[ifos[1]]] +# elif isinstance(ifos[0], GroundBased2G): +# self.ifos = ifos[:1] +# else: +# return ValueError( +# "ifos should be a list of detector names or GroundBased2G objects" +# ) +# self.gmst = float( +# Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad +# ) +# self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex + +# self.transforms = { +# "azimuth": ( +# "ra", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[0], +# ), +# "zenith": ( +# "dec", +# lambda params: zenith_azimuth_to_ra_dec( +# params["zenith"], +# params["azimuth"], +# gmst=self.gmst, +# delta_x=self.delta_x, +# )[1], +# ), +# } + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# rng_keys = jax.random.split(rng_key, 2) +# zenith = jnp.arccos( +# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) +# ) +# azimuth = jax.random.uniform( +# rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi +# ) +# return self.add_name(jnp.stack([zenith, azimuth], axis=1).T) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# zenith = x["zenith"] +# azimuth = x["azimuth"] +# output = jnp.where( +# (zenith > jnp.pi) | (zenith < 0) | (azimuth > 2 * jnp.pi) | (azimuth < 0), +# jnp.zeros_like(0) - jnp.inf, +# jnp.zeros_like(0), +# ) +# return output + jnp.log(jnp.sin(zenith)) + + +# @jaxtyped(typechecker=typechecker) +# class PowerLaw(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). 
+# p(x) ~ x^{\alpha} +# """ + +# xmin: float = 0.0 +# xmax: float = 1.0 +# alpha: float = 0.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: float, +# xmax: float, +# alpha: Union[Int, float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin > 0.0, "With negative alpha, xmin must > 0" +# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha +# if alpha == -1: +# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) +# else: +# self.normalization = (1 + self.alpha) / ( +# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a power-law distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. + +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# if self.alpha == -1: +# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) +# else: +# samples = ( +# self.xmin ** (1.0 + self.alpha) +# + q_samples +# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) +# return log_p + log_in_range + + +# @jaxtyped(typechecker=typechecker) +# class Exponential(Prior): +# """ +# A prior following the power-law with alpha in the range [xmin, xmax). +# p(x) ~ exp(\alpha x) +# """ + +# xmin: float = 0.0 +# xmax: float = jnp.inf +# alpha: float = -1.0 +# normalization: float = 1.0 + +# def __repr__(self): +# return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + +# def __init__( +# self, +# xmin: Float, +# xmax: Float, +# alpha: Union[Int, Float], +# naming: list[str], +# transforms: dict[str, tuple[str, Callable]] = {}, +# **kwargs, +# ): +# super().__init__(naming, transforms) +# if alpha < 0.0: +# assert xmin != -jnp.inf, "With negative alpha, xmin must finite" +# if alpha > 0.0: +# assert xmax != jnp.inf, "With positive alpha, xmax must finite" +# assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" +# assert self.n_dim == 1, "Exponential needs to be 1D distributions" + +# self.xmax = xmax +# self.xmin = xmin +# self.alpha = alpha + +# self.normalization = self.alpha / ( +# jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) +# ) + +# def sample( +# self, rng_key: PRNGKeyArray, n_samples: int +# ) -> dict[str, Float[Array, " n_samples"]]: +# """ +# Sample from a exponential distribution. + +# Parameters +# ---------- +# rng_key : PRNGKeyArray +# A random key to use for sampling. +# n_samples : int +# The number of samples to draw. 
+ +# Returns +# ------- +# samples : dict +# Samples from the distribution. The keys are the names of the parameters. + +# """ +# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) +# samples = ( +# self.xmin +# + jnp.log1p( +# q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) +# ) +# / self.alpha +# ) +# return self.add_name(samples[None]) + +# def log_prob(self, x: dict[str, Float]) -> Float: +# variable = x[self.naming[0]] +# log_in_range = jnp.where( +# (variable >= self.xmax) | (variable <= self.xmin), +# jnp.zeros_like(variable) - jnp.inf, +# jnp.zeros_like(variable), +# ) +# log_p = self.alpha * variable + jnp.log(self.normalization) +# return log_p + log_in_range From b42b0f4622b2a457ae759eabbbbc6c6641af8da4 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 13:01:07 -0400 Subject: [PATCH 038/104] Reformat --- src/jimgw/prior.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 499bd1c2..5e48cb8a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -6,8 +6,6 @@ from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped -from jimgw.single_event.detector import GroundBased2G, detector_preset -from jimgw.single_event.utils import zenith_azimuth_to_ra_dec from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine From a9629b25e984c104375a2710fd2c8e496faaa5b8 Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 26 Jul 2024 13:09:46 -0400 Subject: [PATCH 039/104] Update Prior naming --- src/jimgw/prior.py | 34 +++++++++++++++++----------------- test/test_prior.py | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 5e48cb8a..0f8b70b1 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -132,7 +132,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) -class SequentialTransform(Prior): +class SequentialTransformPrior(Prior): """ Transform a prior distribution by applying a sequence of transforms. """ @@ -182,7 +182,7 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]: return x -class Combine(Prior): +class CombinePrior(Prior): """ A prior class constructed by joinning multiple priors together to form a multivariate prior. This assumes the priors composing the Combine class are independent. @@ -220,12 +220,12 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped(typechecker=typechecker) -class Uniform(SequentialTransform): +class UniformPrior(SequentialTransformPrior): xmin: float xmax: float def __repr__(self): - return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" + return f"UniformPrior(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})" def __init__( self, @@ -234,7 +234,7 @@ def __init__( parameter_names: list[str], ): self.parameter_names = parameter_names - assert self.n_dim == 1, "Uniform needs to be 1D distributions" + assert self.n_dim == 1, "UniformPrior needs to be 1D distributions" self.xmax = xmax self.xmin = xmin super().__init__( @@ -248,43 +248,43 @@ def __init__( @jaxtyped(typechecker=typechecker) -class Sine(SequentialTransform): +class SinePrior(SequentialTransformPrior): """ A prior distribution where the pdf is proportional to sin(x) in the range [0, pi]. 
""" def __repr__(self): - return f"Sine(parameter_names={self.parameter_names})" + return f"SinePrior(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names - assert self.n_dim == 1, "Sine needs to be 1D distributions" + assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( - Uniform(-1.0, 1.0, f"cos_{self.parameter_names}"), + UniformPrior(-1.0, 1.0, f"cos_{self.parameter_names}"), [ArcCosine(([f"cos_{self.parameter_names}"], [self.parameter_names]))], ) @jaxtyped(typechecker=typechecker) -class Cosine(SequentialTransform): +class CosinePrior(SequentialTransformPrior): """ A prior distribution where the pdf is proportional to cos(x) in the range [-pi/2, pi/2]. """ def __repr__(self): - return f"Cosine(parameter_names={self.parameter_names})" + return f"CosinePrior(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names - assert self.n_dim == 1, "Cosine needs to be 1D distributions" + assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( - Uniform(-1.0, 1.0, f"sin_{self.parameter_names}"), + UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names}"), [ArcSine(([f"sin_{self.parameter_names}"], [self.parameter_names]))], ) @jaxtyped(typechecker=typechecker) -class UniformSphere(Combine): +class UniformSpherePrior(CombinePrior): def __repr__(self): return f"UniformSphere(parameter_names={self.parameter_names})" @@ -301,9 +301,9 @@ def __init__(self, parameter_names: list[str], **kwargs): ] super().__init__( [ - Uniform(0.0, 1.0, [self.parameter_names[0]]), - Sine([self.parameter_names[1]]), - Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]), + UniformPrior(0.0, 1.0, [self.parameter_names[0]]), + SinePrior([self.parameter_names[1]]), + UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]), ] ) diff --git a/test/test_prior.py b/test/test_prior.py index b6eb7c87..98803075 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -6,7 +6,7 @@ def test_logistic(self): p = Logit() def test_uniform(self): - p = Uniform(0.0, 10.0, ['x']) + p = UniformPrior(0.0, 10.0, ['x']) samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, -jnp.log(10.0)) From fbba5af599f87fae064e12d6d593398d48b3f8ce Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 26 Jul 2024 13:12:13 -0400 Subject: [PATCH 040/104] Standize Transform naming --- src/jimgw/prior.py | 29 +++++++++++++++++++++++------ src/jimgw/transforms.py | 10 +++++----- test/test_prior.py | 11 ++++++----- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 0f8b70b1..3d26eb04 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -6,7 +6,14 @@ from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped -from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine +from jimgw.transforms import ( + Transform, + LogitTransform, + ScaleTransform, + OffsetTransform, + ArcSineTransform, + ArcCosineTransform, +) class Prior(Distribution): @@ -240,9 +247,11 @@ def __init__( super().__init__( LogisticDistribution(self.parameter_names), [ - Logit((self.parameter_names, self.parameter_names)), - Scale((self.parameter_names, self.parameter_names), xmax - xmin), - Offset((self.parameter_names, self.parameter_names), xmin), + LogitTransform((self.parameter_names, 
self.parameter_names)), + ScaleTransform( + (self.parameter_names, self.parameter_names), xmax - xmin + ), + OffsetTransform((self.parameter_names, self.parameter_names), xmin), ], ) @@ -261,7 +270,11 @@ def __init__(self, parameter_names: list[str]): assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( UniformPrior(-1.0, 1.0, f"cos_{self.parameter_names}"), - [ArcCosine(([f"cos_{self.parameter_names}"], [self.parameter_names]))], + [ + ArcCosineTransform( + ([f"cos_{self.parameter_names}"], [self.parameter_names]) + ) + ], ) @@ -279,7 +292,11 @@ def __init__(self, parameter_names: list[str]): assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names}"), - [ArcSine(([f"sin_{self.parameter_names}"], [self.parameter_names]))], + [ + ArcSineTransform( + ([f"sin_{self.parameter_names}"], [self.parameter_names]) + ) + ], ) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d8ee9536..5022fc88 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -94,7 +94,7 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: return x -class Scale(UnivariateTransform): +class ScaleTransform(UnivariateTransform): scale: Float def __init__( @@ -107,7 +107,7 @@ def __init__( self.transform_func = lambda x: x * self.scale -class Offset(UnivariateTransform): +class OffsetTransform(UnivariateTransform): offset: Float def __init__( @@ -120,7 +120,7 @@ def __init__( self.transform_func = lambda x: x + self.offset -class Logit(UnivariateTransform): +class LogitTransform(UnivariateTransform): """ Logit transform following @@ -139,7 +139,7 @@ def __init__( self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) -class ArcSine(UnivariateTransform): +class ArcSineTransform(UnivariateTransform): """ ArcSine transformation @@ -158,7 +158,7 @@ def __init__( self.transform_func = lambda x: jnp.arcsin(x) -class ArcCosine(UnivariateTransform): +class ArcCosineTransform(UnivariateTransform): """ ArcCosine transformation diff --git a/test/test_prior.py b/test/test_prior.py index 98803075..6431d7d0 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -3,20 +3,21 @@ class TestUnivariatePrior: def test_logistic(self): - p = Logit() + p = LogitTransform() def test_uniform(self): - p = UniformPrior(0.0, 10.0, ['x']) + p = UniformPrior(0.0, 10.0, ["x"]) samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, -jnp.log(10.0)) + class TestPriorOperations: def test_combine(self): raise NotImplementedError - + def test_sequence(self): raise NotImplementedError - + def test_factor(self): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError From 4a5d932aea617a36d2e7e26a7bf61ff0007953d6 Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 26 Jul 2024 15:13:55 -0400 Subject: [PATCH 041/104] Fixing naming problem and add base_distribution tracer --- src/jimgw/prior.py | 44 ++++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3d26eb04..dc6a492b 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -27,6 +27,7 @@ class Prior(Distribution): """ parameter_names: list[str] + composite: bool = False @property def n_dim(self): @@ -61,7 +62,6 @@ def sample( def log_prob(self, x: dict[str, Array]) -> Float: raise NotImplementedError - @jaxtyped(typechecker=typechecker) class 
LogisticDistribution(Prior): @@ -70,6 +70,7 @@ def __repr__(self): def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) + self.composite = False assert self.n_dim == 1, "LogisticDistribution needs to be 1D distributions" def sample( @@ -108,6 +109,7 @@ def __repr__(self): def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) + self.composite = False assert ( self.n_dim == 1 ), "StandardNormalDistribution needs to be 1D distributions" @@ -161,11 +163,12 @@ def __init__( self.parameter_names = base_prior.parameter_names for transform in transforms: self.parameter_names = transform.propagate_name(self.parameter_names) + self.composite = True def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: - output = self.sample_base(rng_key, n_samples) + output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) def log_prob(self, x: dict[str, Float]) -> Float: @@ -178,11 +181,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: output -= log_jacobian return output - def sample_base( - self, rng_key: PRNGKeyArray, n_samples: int - ) -> dict[str, Float[Array, " n_samples"]]: - return self.base_prior.sample(rng_key, n_samples) - def transform(self, x: dict[str, Float]) -> dict[str, Float]: for transform in self.transforms: x = transform.forward(x) @@ -195,10 +193,12 @@ class CombinePrior(Prior): This assumes the priors composing the Combine class are independent. """ - priors: list[Prior] = field(default_factory=list) + base_prior: list[Prior] = field(default_factory=list) def __repr__(self): - return f"Combine(priors={self.priors}, parameter_names={self.parameter_names})" + return ( + f"Combine(priors={self.base_prior}, parameter_names={self.parameter_names})" + ) def __init__( self, @@ -207,21 +207,22 @@ def __init__( parameter_names = [] for prior in priors: parameter_names += prior.parameter_names - self.priors = priors + self.base_prior = priors self.parameter_names = parameter_names + self.composite = True def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: output = {} - for prior in self.priors: + for prior in self.base_prior: rng_key, subkey = jax.random.split(rng_key) output.update(prior.sample(subkey, n_samples)) return output def log_prob(self, x: dict[str, Float]) -> Float: output = 0.0 - for prior in self.priors: + for prior in self.base_prior: output += prior.log_prob(x) return output @@ -269,10 +270,10 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, f"cos_{self.parameter_names}"), + UniformPrior(-1.0, 1.0, [f"cos_{self.parameter_names[0]}"]), [ ArcCosineTransform( - ([f"cos_{self.parameter_names}"], [self.parameter_names]) + ([f"cos_{self.parameter_names[0]}"], [self.parameter_names[0]]) ) ], ) @@ -291,10 +292,10 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names}"), + UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names[0]}"), [ ArcSineTransform( - ([f"sin_{self.parameter_names}"], [self.parameter_names]) + ([f"sin_{self.parameter_names[0]}"], [self.parameter_names[0]]) ) ], ) @@ -324,6 +325,17 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) +def 
trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output # ====================== Things below may need rework ====================== From 27dc87090ff3ed29cb52792f957d52de3595a6c8 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 15:37:19 -0400 Subject: [PATCH 042/104] Updated powerlaw --- src/jimgw/prior.py | 43 +++++++++++++++++++++++++++++++ src/jimgw/transforms.py | 57 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index dc6a492b..16c28b8e 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -13,6 +13,8 @@ OffsetTransform, ArcSineTransform, ArcCosineTransform, + PowerLawTransform, + ParetoTransform, ) @@ -62,6 +64,7 @@ def sample( def log_prob(self, x: dict[str, Array]) -> Float: raise NotImplementedError + @jaxtyped(typechecker=typechecker) class LogisticDistribution(Prior): @@ -325,6 +328,45 @@ def __init__(self, parameter_names: list[str], **kwargs): ] ) + +@jaxtyped(typechecker=typechecker) +class PowerLawPrior(SequentialTransformPrior): + xmin: float + xmax: float + alpha: float + + def __repr__(self): + return f"PowerLaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" + + def __init__( + self, + xmin: float, + xmax: float, + alpha: float, + parameter_names: list[str], + ): + self.parameter_names = parameter_names + assert self.n_dim == 1, "Power law needs to be 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + if self.alpha == -1.0: + transform = ParetoTransform( + (self.parameter_names, self.parameter_names), xmin, xmax + ) + else: + transform = PowerLawTransform( + (self.parameter_names, self.parameter_names), xmin, xmax, alpha + ) + super().__init__( + LogisticDistribution(self.parameter_names), + [ + Logit((self.parameter_names, self.parameter_names)), + transform, + ], + ) + + def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: if prior.composite: if isinstance(prior.base_prior, list): @@ -337,6 +379,7 @@ def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: return output + # ====================== Things below may need rework ====================== diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 5022fc88..527ea7b8 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -175,3 +175,60 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arccos(x) + + +class PowerLawTransform(UnivariateTransform): + """ + PowerLaw transformation + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ """ + + xmin: Float + xmax: Float + alpha: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + alpha: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.alpha = alpha + self.transform_func = ( + lambda x: ( + self.xmin ** (1.0 + self.alpha) + + x + * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) + ** (1.0 / (1.0 + self.alpha)), + ) + + +class ParetoTransform(UnivariateTransform): + """ + Pareto transformation: Power law when alpha = -1 + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + xmin: Float, + xmax: Float, + ): + super().__init__(name_mapping) + self.xmin = xmin + self.xmax = xmax + self.transform_func = lambda x: self.xmin * jnp.exp( + x * jnp.log(self.xmax / self.xmin) + ) From f7b883d6d0bff8f5be1eaddfbd042e601b1b620a Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 15:38:28 -0400 Subject: [PATCH 043/104] Updated powerlaw --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 16c28b8e..9f52fdbb 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -361,7 +361,7 @@ def __init__( super().__init__( LogisticDistribution(self.parameter_names), [ - Logit((self.parameter_names, self.parameter_names)), + LogitTransform((self.parameter_names, self.parameter_names)), transform, ], ) From 1665b1cfbc0400647f1ccabb2ed79479c4f32c43 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 16:26:20 -0400 Subject: [PATCH 044/104] Updated powerlaw --- src/jimgw/test.py | 70 +++++++++++++++++++++++++++++++++++++++++ src/jimgw/transforms.py | 12 +++---- 2 files changed, 74 insertions(+), 8 deletions(-) create mode 100644 src/jimgw/test.py diff --git a/src/jimgw/test.py b/src/jimgw/test.py new file mode 100644 index 00000000..c114b190 --- /dev/null +++ b/src/jimgw/test.py @@ -0,0 +1,70 @@ +import jax +import jax.numpy as jnp +from jimgw.prior import PowerLawPrior + +alpha = -1.0 +xmin = 1.0 +xmax = 10.0 + + +q_samples = jax.random.uniform(jax.random.PRNGKey(42), (100,), minval=0.0, maxval=1.0) +if alpha == -1: + samples = xmin * jnp.exp(q_samples * jnp.log(xmax / xmin)) +else: + samples = ( + xmin ** (1.0 + alpha) + + q_samples * (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)) + ) ** (1.0 / (1.0 + alpha)) +samples[None] + +PowerLawPrior(xmin, xmax, alpha, ["x"]).sample(jax.random.PRNGKey(42), 100) + + +############## +if alpha == -1: + normalization = float(1.0 / jnp.log(xmax / xmin)) +else: + normalization = (1 + alpha) / (xmax ** (1 + alpha) - xmin ** (1 + alpha)) +variable = 1.5 +log_in_range = jnp.where( + (variable >= xmax) | (variable <= xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), +) +log_p = alpha * jnp.log(variable) + jnp.log(normalization) +log_p + log_in_range + +PowerLawPrior(xmin, xmax, alpha, ["x"]).log_prob({"x": variable}) + + +# transform_func = lambda x: ( +# xmin ** (1.0 + alpha) + x * (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)) +# ) ** (1.0 / (1.0 + alpha)) +# input_params = variable +# assert_rank(input_params, 0) +# output_params = transform_func(input_params) +# jacobian = jax.jacfwd(transform_func)(input_params) +# x[variable] = output_params +# return x, jnp.log(jacobian) + + 
+import numpy as np + +alpha = -2.0 +xmin = 1.0 +xmax = 20.0 +p = PowerLawPrior(xmin, xmax, alpha, ["x"]) +grid = np.linspace(xmin, xmax, 20) +transform = [] +log_prob = [] +for y in grid: + transform.append(p.transform({"x": y})["x"].item()) + log_prob.append(np.exp(p.log_prob({"x": y}).item())) +import matplotlib.pyplot as plt + +plt.plot(grid, transform) +plt.savefig("transform.png") +plt.close() +plt.plot(transform, log_prob) +plt.savefig("log_prob.png") +plt.close() diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 527ea7b8..da84a059 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -201,14 +201,10 @@ def __init__( self.xmin = xmin self.xmax = xmax self.alpha = alpha - self.transform_func = ( - lambda x: ( - self.xmin ** (1.0 + self.alpha) - + x - * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) - ** (1.0 / (1.0 + self.alpha)), - ) + self.transform_func = lambda x: ( + self.xmin ** (1.0 + self.alpha) + + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) + ) ** (1.0 / (1.0 + self.alpha)) class ParetoTransform(UnivariateTransform): From f51847d614d14de03569317e28cc738a9d89d915 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 16:49:15 -0400 Subject: [PATCH 045/104] Updated powerlaw --- src/jimgw/test.py | 70 ----------------------------------------------- 1 file changed, 70 deletions(-) delete mode 100644 src/jimgw/test.py diff --git a/src/jimgw/test.py b/src/jimgw/test.py deleted file mode 100644 index c114b190..00000000 --- a/src/jimgw/test.py +++ /dev/null @@ -1,70 +0,0 @@ -import jax -import jax.numpy as jnp -from jimgw.prior import PowerLawPrior - -alpha = -1.0 -xmin = 1.0 -xmax = 10.0 - - -q_samples = jax.random.uniform(jax.random.PRNGKey(42), (100,), minval=0.0, maxval=1.0) -if alpha == -1: - samples = xmin * jnp.exp(q_samples * jnp.log(xmax / xmin)) -else: - samples = ( - xmin ** (1.0 + alpha) - + q_samples * (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)) - ) ** (1.0 / (1.0 + alpha)) -samples[None] - -PowerLawPrior(xmin, xmax, alpha, ["x"]).sample(jax.random.PRNGKey(42), 100) - - -############## -if alpha == -1: - normalization = float(1.0 / jnp.log(xmax / xmin)) -else: - normalization = (1 + alpha) / (xmax ** (1 + alpha) - xmin ** (1 + alpha)) -variable = 1.5 -log_in_range = jnp.where( - (variable >= xmax) | (variable <= xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable), -) -log_p = alpha * jnp.log(variable) + jnp.log(normalization) -log_p + log_in_range - -PowerLawPrior(xmin, xmax, alpha, ["x"]).log_prob({"x": variable}) - - -# transform_func = lambda x: ( -# xmin ** (1.0 + alpha) + x * (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)) -# ) ** (1.0 / (1.0 + alpha)) -# input_params = variable -# assert_rank(input_params, 0) -# output_params = transform_func(input_params) -# jacobian = jax.jacfwd(transform_func)(input_params) -# x[variable] = output_params -# return x, jnp.log(jacobian) - - -import numpy as np - -alpha = -2.0 -xmin = 1.0 -xmax = 20.0 -p = PowerLawPrior(xmin, xmax, alpha, ["x"]) -grid = np.linspace(xmin, xmax, 20) -transform = [] -log_prob = [] -for y in grid: - transform.append(p.transform({"x": y})["x"].item()) - log_prob.append(np.exp(p.log_prob({"x": y}).item())) -import matplotlib.pyplot as plt - -plt.plot(grid, transform) -plt.savefig("transform.png") -plt.close() -plt.plot(transform, log_prob) -plt.savefig("log_prob.png") -plt.close() From f7876e6403a2ce90db95bcbd11a3daa335730b0f Mon Sep 17 
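The scratch checks deleted above can be kept as a self-contained sketch. This assumes the PowerLawPrior API as of patch 044 (base and output both named "x"); the sample size and tolerance are arbitrary Monte Carlo choices, not part of jimgw:

    import jax
    import jax.numpy as jnp
    from jimgw.prior import PowerLawPrior

    # Compare the empirical CDF of PowerLawPrior samples against the
    # analytic CDF of p(x) ~ x^alpha on [xmin, xmax] (alpha != -1 here;
    # alpha == -1 goes through ParetoTransform instead).
    xmin, xmax, alpha = 1.0, 20.0, -2.0
    p = PowerLawPrior(xmin, xmax, alpha, ["x"])
    samples = p.sample(jax.random.PRNGKey(0), 50_000)["x"]

    grid = jnp.linspace(xmin, xmax, 100)
    cdf = (grid ** (1.0 + alpha) - xmin ** (1.0 + alpha)) / (
        xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)
    )
    empirical = jax.vmap(lambda g: jnp.mean(samples <= g))(grid)
    assert jnp.max(jnp.abs(empirical - cdf)) < 0.02  # loose Monte Carlo bound
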
00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 26 Jul 2024 17:09:42 -0400 Subject: [PATCH 046/104] Updated prior.py to include UniformComponenChirpMassPrior --- src/jimgw/prior.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 9f52fdbb..c182ee16 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -367,6 +367,22 @@ def __init__( ) +@jaxtyped(typechecker=typechecker) +class UniformComponenChirpMassPrior(PowerLawPrior): + """ + A prior in the range [xmin, xmax) for chirp mass which assumes the + component mass to be uniform. + + p(\cal M) ~ \cal M + """ + + def __repr__(self): + return f"UniformComponentChirpMass(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + + def __init__(self, xmin: float, xmax: float, parameter_names: list[str]): + super().__init__(xmin, xmax, 1.0, parameter_names) + + def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: if prior.composite: if isinstance(prior.base_prior, list): From 0b2c7cf296b152753b0f757dec7333f97dd668c9 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 21:16:29 -0400 Subject: [PATCH 047/104] Fix priors --- src/jimgw/prior.py | 26 ++++++++++++-------------- src/jimgw/transforms.py | 18 ------------------ 2 files changed, 12 insertions(+), 32 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index dc6a492b..b09e8cc7 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -12,7 +12,6 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - ArcCosineTransform, ) @@ -30,7 +29,7 @@ class Prior(Distribution): composite: bool = False @property - def n_dim(self): + def n_dim(self) -> int: return len(self.parameter_names) def __init__(self, parameter_names: list[str]): @@ -270,12 +269,13 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, [f"cos_{self.parameter_names[0]}"]), + CosinePrior([f"{self.parameter_names[0]}"]), [ - ArcCosineTransform( - ([f"cos_{self.parameter_names[0]}"], [self.parameter_names[0]]) + OffsetTransform( + (([f"{self.parameter_names[0]}"], [f"{self.parameter_names[0]}"])), + jnp.pi / 2, ) - ], + ] ) @@ -292,7 +292,7 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, f"sin_{self.parameter_names[0]}"), + UniformPrior(-1.0, 1.0, [f"sin_{self.parameter_names[0]}"]), [ ArcSineTransform( ([f"sin_{self.parameter_names[0]}"], [self.parameter_names[0]]) @@ -308,14 +308,12 @@ def __repr__(self): return f"UniformSphere(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): - assert ( - len(parameter_names) == 1 - ), "UniformSphere only takes the name of the vector" - parameter_names = parameter_names[0] + self.parameter_names = parameter_names + assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector" self.parameter_names = [ - f"{parameter_names}_mag", - f"{parameter_names}_theta", - f"{parameter_names}_phi", + f"{self.parameter_names[0]}_mag", + f"{self.parameter_names[0]}_theta", + f"{self.parameter_names[0]}_phi", ] super().__init__( [ diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 5022fc88..3ea71e9c 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -157,21 
+157,3 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) - -class ArcCosineTransform(UnivariateTransform): - """ - ArcCosine transformation - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - self.transform_func = lambda x: jnp.arccos(x) From d917f96da3b764980885fc811b7120b4a9183342 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 26 Jul 2024 21:17:57 -0400 Subject: [PATCH 048/104] Add prior tests --- test/test_prior.py | 69 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 13 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index 6431d7d0..48902b3b 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -1,23 +1,66 @@ from jimgw.prior import * +import scipy.stats as stats class TestUnivariatePrior: def test_logistic(self): - p = LogitTransform() - - def test_uniform(self): - p = UniformPrior(0.0, 10.0, ["x"]) - samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000) + p = LogisticDistribution(["x"]) + # Check that the log_prob is finite + samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) - assert jnp.allclose(log_prob, -jnp.log(10.0)) + assert jnp.all(jnp.isfinite(log_prob)) + # Cross-check log_prob with scipy.stats.logistic + x = jnp.linspace(-10.0, 10.0, 1000) + assert jnp.allclose(jax.vmap(p.log_prob)(p.add_name(x[None])), stats.logistic.logpdf(x)) + def test_standard_normal(self): + p = StandardNormalDistribution(["x"]) + # Check that the log_prob is finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Cross-check log_prob with scipy.stats.norm + x = jnp.linspace(-10.0, 10.0, 1000) + assert jnp.allclose(jax.vmap(p.log_prob)(p.add_name(x[None])), stats.norm.logpdf(x)) -class TestPriorOperations: - def test_combine(self): - raise NotImplementedError + def test_uniform(self): + xmin, xmax = -10.0, 10.0 + p = UniformPrior(xmin, xmax, ["x"]) + # Check that the log_prob is correct in the support + samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin))) - def test_sequence(self): - raise NotImplementedError + def test_sine(self): + p = SinePrior(["x"]) + # Check that the log_prob is finite + samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Check that the log_prob is correct in the support + x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) + y = jax.vmap(p.base_prior.base_prior.transform)(x) + y = jax.vmap(p.base_prior.transform)(y) + y = jax.vmap(p.transform)(y) + assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.sin(y['x'])/2.0)) + + def test_cosine(self): + p = CosinePrior(["x"]) + # Check that the log_prob is finite + samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) + # Check that the log_prob is correct in the support + x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) + y = jax.vmap(p.base_prior.transform)(x) + y = jax.vmap(p.transform)(y) 
+ assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.cos(y['x'])/2.0)) - def test_factor(self): - raise NotImplementedError + def test_uniform_sphere(self): + p = UniformSpherePrior(["x"]) + # Check that the log_prob is finite + samples = {} + for i in range(3): + samples.update(trace_prior_parent(p)[i].sample(jax.random.PRNGKey(0), 10000)) + log_prob = jax.vmap(p.log_prob)(samples) + assert jnp.all(jnp.isfinite(log_prob)) From 142fe5446db1044a3c6385060a963afbfa65cef3 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:19:25 -0400 Subject: [PATCH 049/104] Set constraint on powerlaw prior input --- src/jimgw/prior.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index c1baca3c..b349d074 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -348,6 +348,8 @@ def __init__( self.xmax = xmax self.xmin = xmin self.alpha = alpha + assert self.xmin < self.xmax, "xmin must be less than xmax" + assert self.xmin > 0.0, "x must be positive" if self.alpha == -1.0: transform = ParetoTransform( (self.parameter_names, self.parameter_names), xmin, xmax From fa6a6e25aa472e0d19233aca66c8daca9e60214e Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:20:23 -0400 Subject: [PATCH 050/104] Added test_power_law --- test/test_prior.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/test/test_prior.py b/test/test_prior.py index 48902b3b..0f9e4566 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -64,3 +64,33 @@ def test_uniform_sphere(self): samples.update(trace_prior_parent(p)[i].sample(jax.random.PRNGKey(0), 10000)) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) + + def test_power_law(self): + from bilby.core.prior.analytical import PowerLaw + def func(alpha): + xmin = 1.0 + xmax = 20.0 + p = PowerLawPrior(xmin, xmax, alpha, ["x"]) + # Check that all the samples are finite + powerlaw_samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) + + # Check that all the log_probs are finite + samples = trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000)['x'] + base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) + assert jnp.all(jnp.isfinite(log_p)) + + # Check that the log_prob is correct in the support + samples = jnp.linspace(xmin, xmax, 1000) + transformed_samples = jax.vmap(p.transform)({'x': samples})['x'] + assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), PowerLaw(alpha, xmin, xmax).ln_prob(transformed_samples)) + + # Test Pareto Transform + func(-1.0) + # Test other values of alpha + positive_alpha = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0] + for alpha_val in positive_alpha: + func(alpha_val) + negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] + for alpha_val in negative_alpha: + func(alpha_val) From a8c06737056cf32ef7347f38af3ceceeabc3818d Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:22:25 -0400 Subject: [PATCH 051/104] Updated ParetoTransform to avoid divide by zero --- src/jimgw/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4c1380cf..48e29a06 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -207,5 +207,5 @@ def __init__( self.xmin = xmin self.xmax = xmax self.transform_func = lambda x: self.xmin * 
jnp.exp( - x * jnp.log(self.xmax / self.xmin) + x * jnp.log(self.xmax) - x * jnp.log(self.xmin) ) From 6ce5b369ee8cb7eead770412341dda61141b2c3b Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:29:19 -0400 Subject: [PATCH 052/104] Revert update on ParetoTransform --- src/jimgw/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 48e29a06..4c1380cf 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -207,5 +207,5 @@ def __init__( self.xmin = xmin self.xmax = xmax self.transform_func = lambda x: self.xmin * jnp.exp( - x * jnp.log(self.xmax) - x * jnp.log(self.xmin) + x * jnp.log(self.xmax / self.xmin) ) From e9511b9aa64e5dd7851760b49f9dbf0bdfe87061 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 12:29:49 -0400 Subject: [PATCH 053/104] Updated test_prior.py --- test/test_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prior.py b/test/test_prior.py index 0f9e4566..83392751 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -78,7 +78,7 @@ def func(alpha): # Check that all the log_probs are finite samples = trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000)['x'] base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) - assert jnp.all(jnp.isfinite(log_p)) + assert jnp.all(jnp.isfinite(base_log_p)) # Check that the log_prob is correct in the support samples = jnp.linspace(xmin, xmax, 1000) From 54082eaa7707c115e90ee842f5d791a12aa80bfc Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:00:06 -0400 Subject: [PATCH 054/104] Updated test_prior.py --- test/test_prior.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index 83392751..cee29934 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -68,8 +68,8 @@ def test_uniform_sphere(self): def test_power_law(self): from bilby.core.prior.analytical import PowerLaw def func(alpha): - xmin = 1.0 - xmax = 20.0 + xmin = 0.05 + xmax = 10.0 p = PowerLawPrior(xmin, xmax, alpha, ["x"]) # Check that all the samples are finite powerlaw_samples = p.sample(jax.random.PRNGKey(0), 10000) @@ -81,10 +81,15 @@ def func(alpha): assert jnp.all(jnp.isfinite(base_log_p)) # Check that the log_prob is correct in the support - samples = jnp.linspace(xmin, xmax, 1000) + samples = jnp.linspace(-10.0, 10.0, 1000) transformed_samples = jax.vmap(p.transform)({'x': samples})['x'] - assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), PowerLaw(alpha, xmin, xmax).ln_prob(transformed_samples)) - + # cut off the samples that are outside the support + samples = samples[transformed_samples >= xmin] + transformed_samples = transformed_samples[transformed_samples >= xmin] + samples = samples[transformed_samples <= xmax] + transformed_samples = transformed_samples[transformed_samples <= xmax] + assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), PowerLaw(alpha, xmin, xmax).ln_prob(transformed_samples), atol=1e-4) + # Test Pareto Transform func(-1.0) # Test other values of alpha @@ -93,4 +98,4 @@ def func(alpha): func(alpha_val) negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] for alpha_val in negative_alpha: - func(alpha_val) + func(alpha_val) \ No newline at end of file From 9ad64a1ca0bf0a2c0ea945b5cae3f02c38c07199 Mon Sep 17 00:00:00 2001 From: xuyuon 
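For context on the comparison above: SequentialTransformPrior.log_prob evaluates the base density and subtracts the log-Jacobian of each forward transform, so the test feeds base (logistic) samples to log_prob while the analytic power-law density is evaluated at the transformed values. A minimal standalone illustration of that change-of-variables identity, in pure JAX with illustrative names (not jimgw API):

    import jax
    import jax.numpy as jnp

    # If x = T(z) with z ~ Logistic, then log p_X(T(z)) = log p_Z(z) - log|dT/dz|.
    # For T the logistic sigmoid (LogitTransform's forward map), p_X is
    # Uniform(0, 1), so the right-hand side should be identically zero.
    T = lambda z: 1.0 / (1.0 + jnp.exp(-z))
    log_p_z = lambda z: -z - 2.0 * jax.nn.softplus(-z)  # logistic log-density

    z = jnp.linspace(-5.0, 5.0, 11)
    dTdz = jax.vmap(jax.grad(T))(z)
    log_p_x = log_p_z(z) - jnp.log(jnp.abs(dTdz))
    assert jnp.allclose(log_p_x, 0.0, atol=1e-5)
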
<116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:05:02 -0400 Subject: [PATCH 055/104] Updated test_prior.py --- test/test_prior.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/test_prior.py b/test/test_prior.py index cee29934..c0106fa9 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -5,6 +5,9 @@ class TestUnivariatePrior: def test_logistic(self): p = LogisticDistribution(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -15,6 +18,9 @@ def test_logistic(self): def test_standard_normal(self): p = StandardNormalDistribution(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -26,6 +32,9 @@ def test_standard_normal(self): def test_uniform(self): xmin, xmax = -10.0, 10.0 p = UniformPrior(xmin, xmax, ["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is correct in the support samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -33,6 +42,9 @@ def test_uniform(self): def test_sine(self): p = SinePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -46,6 +58,9 @@ def test_sine(self): def test_cosine(self): p = CosinePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -58,6 +73,10 @@ def test_cosine(self): def test_uniform_sphere(self): p = UniformSpherePrior(["x"]) + # Check that all the samples are finite + samples = p.sample(jax.random.PRNGKey(0), 10000) + assert jnp.all(jnp.isfinite(samples['x'])) + # Check that the log_prob is finite samples = {} for i in range(3): From ce437c417d5d9279f0fed75dce44c2f179edca62 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:12:32 -0400 Subject: [PATCH 056/104] Updated test_prior.py --- test/test_prior.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index c0106fa9..7cd4d8f8 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -75,8 +75,9 @@ def test_uniform_sphere(self): p = UniformSpherePrior(["x"]) # Check that all the samples are finite samples = p.sample(jax.random.PRNGKey(0), 10000) - assert jnp.all(jnp.isfinite(samples['x'])) - + assert jnp.all(jnp.isfinite(samples['x_mag'])) + assert jnp.all(jnp.isfinite(samples['x_theta'])) + assert jnp.all(jnp.isfinite(samples['x_phi'])) # Check that the log_prob is finite samples = {} for i in range(3): From 18d5184837c2adc9e2db7df0c9d5b071ac406cf6 Mon Sep 17 00:00:00 2001 From: xuyuon 
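The corrected keys above follow UniformSpherePrior's "<name>_mag" / "<name>_theta" / "<name>_phi" convention. A short usage sketch; the spherical-to-Cartesian conversion (with theta taken as the polar angle) is illustrative and not part of jimgw:

    import jax
    import jax.numpy as jnp
    from jimgw.prior import UniformSpherePrior

    p = UniformSpherePrior(["s1"])
    samples = p.sample(jax.random.PRNGKey(1), 5)
    mag, theta, phi = samples["s1_mag"], samples["s1_theta"], samples["s1_phi"]
    x = mag * jnp.sin(theta) * jnp.cos(phi)
    y = mag * jnp.sin(theta) * jnp.sin(phi)
    z = mag * jnp.cos(theta)
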
<116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:17:26 -0400 Subject: [PATCH 057/104] Updated test_prior.py --- test/test_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prior.py b/test/test_prior.py index 7cd4d8f8..cdfdbae0 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -96,7 +96,7 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000)['x'] + samples = (trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000))['x'] base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) assert jnp.all(jnp.isfinite(base_log_p)) From 7c05d14df30b4ba8894967da0f1334ba95419ce2 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 13:27:02 -0400 Subject: [PATCH 058/104] Updated test_prior.py --- test/test_prior.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index cdfdbae0..11206877 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -81,7 +81,7 @@ def test_uniform_sphere(self): # Check that the log_prob is finite samples = {} for i in range(3): - samples.update(trace_prior_parent(p)[i].sample(jax.random.PRNGKey(0), 10000)) + samples.update(trace_prior_parent(p, [])[i].sample(jax.random.PRNGKey(0), 10000)) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) @@ -96,7 +96,7 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = (trace_prior_parent(p)[0].sample(jax.random.PRNGKey(0), 10000))['x'] + samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x'] base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) assert jnp.all(jnp.isfinite(base_log_p)) From f647bf0155a2cb8b6a4cfe460cd71c348e344e31 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 29 Jul 2024 13:31:38 -0400 Subject: [PATCH 059/104] Remove unnecessary test --- test/test_prior.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index 11206877..d9dbe27a 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -5,9 +5,6 @@ class TestUnivariatePrior: def test_logistic(self): p = LogisticDistribution(["x"]) - # Check that all the samples are finite - samples = p.sample(jax.random.PRNGKey(0), 10000) - assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -18,9 +15,6 @@ def test_logistic(self): def test_standard_normal(self): p = StandardNormalDistribution(["x"]) - # Check that all the samples are finite - samples = p.sample(jax.random.PRNGKey(0), 10000) - assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) @@ -118,4 +112,4 @@ def func(alpha): func(alpha_val) negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0] for alpha_val in negative_alpha: - func(alpha_val) \ No newline at end of file + func(alpha_val) From 6171ce5d7d00935eedb1bebb0067768b1364ca57 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 29 Jul 2024 13:32:24 -0400 Subject: [PATCH 060/104] Reformat --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 
b349d074..2d22c368 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -278,7 +278,7 @@ def __init__(self, parameter_names: list[str]): (([f"{self.parameter_names[0]}"], [f"{self.parameter_names[0]}"])), jnp.pi / 2, ) - ] + ], ) From ab7447be5d5ddae1f11768248ef825975a673bda Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 29 Jul 2024 15:15:43 -0400 Subject: [PATCH 061/104] Change naming --- src/jimgw/prior.py | 142 +++++++++++++++------------------------------ 1 file changed, 46 insertions(+), 96 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 2d22c368..3697df2f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -250,11 +250,25 @@ def __init__( super().__init__( LogisticDistribution(self.parameter_names), [ - LogitTransform((self.parameter_names, self.parameter_names)), + LogitTransform( + ( + [f"{self.parameter_names[0]}_base"], + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + ) + ), ScaleTransform( - (self.parameter_names, self.parameter_names), xmax - xmin + ( + ( + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + [f"{self.parameter_names[0]}-({xmin})"], + ), + xmax - xmin, + ), + ), + OffsetTransform( + ([f"{self.parameter_names[0]}-({xmin})"], self.parameter_names), + xmin, ), - OffsetTransform((self.parameter_names, self.parameter_names), xmin), ], ) @@ -275,7 +289,12 @@ def __init__(self, parameter_names: list[str]): CosinePrior([f"{self.parameter_names[0]}"]), [ OffsetTransform( - (([f"{self.parameter_names[0]}"], [f"{self.parameter_names[0]}"])), + ( + ( + [f"{self.parameter_names[0]}-pi/2"], + [f"{self.parameter_names[0]}"], + ) + ), jnp.pi / 2, ) ], @@ -295,10 +314,10 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "CosinePrior needs to be 1D distributions" super().__init__( - UniformPrior(-1.0, 1.0, [f"sin_{self.parameter_names[0]}"]), + UniformPrior(-1.0, 1.0, [f"sin({self.parameter_names[0]})"]), [ ArcSineTransform( - ([f"sin_{self.parameter_names[0]}"], [self.parameter_names[0]]) + ([f"sin({self.parameter_names[0]})"], [self.parameter_names[0]]) ) ], ) @@ -308,7 +327,7 @@ def __init__(self, parameter_names: list[str]): class UniformSpherePrior(CombinePrior): def __repr__(self): - return f"UniformSphere(parameter_names={self.parameter_names})" + return f"UniformSpherePrior(parameter_names={self.parameter_names})" def __init__(self, parameter_names: list[str], **kwargs): self.parameter_names = parameter_names @@ -334,7 +353,7 @@ class PowerLawPrior(SequentialTransformPrior): alpha: float def __repr__(self): - return f"PowerLaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" + return f"PowerLawPrior(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.parameter_names})" def __init__( self, @@ -352,35 +371,45 @@ def __init__( assert self.xmin > 0.0, "x must be positive" if self.alpha == -1.0: transform = ParetoTransform( - (self.parameter_names, self.parameter_names), xmin, xmax + ([f"{self.parameter_names[0]}_before_transform"], self.parameter_names), + xmin, + xmax, ) else: transform = PowerLawTransform( - (self.parameter_names, self.parameter_names), xmin, xmax, alpha + ([f"{self.parameter_names[0]}_before_transform"], self.parameter_names), + xmin, + xmax, + alpha, ) super().__init__( LogisticDistribution(self.parameter_names), [ - LogitTransform((self.parameter_names, self.parameter_names)), + LogitTransform( + ( + [f"{self.parameter_names[0]}_base"], + 
[f"{self.parameter_names[0]}_before_transform"], + ) + ), transform, ], ) @jaxtyped(typechecker=typechecker) -class UniformComponenChirpMassPrior(PowerLawPrior): +class UniformInComponentsChirpMassPrior(PowerLawPrior): """ A prior in the range [xmin, xmax) for chirp mass which assumes the - component mass to be uniform. + component masses to be uniformly distributed. - p(\cal M) ~ \cal M + p(M_c) ~ M_c """ def __repr__(self): - return f"UniformComponentChirpMass(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" - def __init__(self, xmin: float, xmax: float, parameter_names: list[str]): - super().__init__(xmin, xmax, 1.0, parameter_names) + def __init__(self, xmin: float, xmax: float): + super().__init__(xmin, xmax, 1.0, ["M_c"]) def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: @@ -588,85 +617,6 @@ def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: # return output + jnp.log(jnp.sin(zenith)) -# @jaxtyped(typechecker=typechecker) -# class PowerLaw(Prior): -# """ -# A prior following the power-law with alpha in the range [xmin, xmax). -# p(x) ~ x^{\alpha} -# """ - -# xmin: float = 0.0 -# xmax: float = 1.0 -# alpha: float = 0.0 -# normalization: float = 1.0 - -# def __repr__(self): -# return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" - -# def __init__( -# self, -# xmin: float, -# xmax: float, -# alpha: Union[Int, float], -# naming: list[str], -# transforms: dict[str, tuple[str, Callable]] = {}, -# **kwargs, -# ): -# super().__init__(naming, transforms) -# if alpha < 0.0: -# assert xmin > 0.0, "With negative alpha, xmin must > 0" -# assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" -# self.xmax = xmax -# self.xmin = xmin -# self.alpha = alpha -# if alpha == -1: -# self.normalization = float(1.0 / jnp.log(self.xmax / self.xmin)) -# else: -# self.normalization = (1 + self.alpha) / ( -# self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) -# ) - -# def sample( -# self, rng_key: PRNGKeyArray, n_samples: int -# ) -> dict[str, Float[Array, " n_samples"]]: -# """ -# Sample from a power-law distribution. - -# Parameters -# ---------- -# rng_key : PRNGKeyArray -# A random key to use for sampling. -# n_samples : int -# The number of samples to draw. - -# Returns -# ------- -# samples : dict -# Samples from the distribution. The keys are the names of the parameters. 
- -# """ -# q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) -# if self.alpha == -1: -# samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin)) -# else: -# samples = ( -# self.xmin ** (1.0 + self.alpha) -# + q_samples -# * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) -# ) ** (1.0 / (1.0 + self.alpha)) -# return self.add_name(samples[None]) - -# def log_prob(self, x: dict[str, Float]) -> Float: -# variable = x[self.naming[0]] -# log_in_range = jnp.where( -# (variable >= self.xmax) | (variable <= self.xmin), -# jnp.zeros_like(variable) - jnp.inf, -# jnp.zeros_like(variable), -# ) -# log_p = self.alpha * jnp.log(variable) + jnp.log(self.normalization) -# return log_p + log_in_range - - # @jaxtyped(typechecker=typechecker) # class Exponential(Prior): # """ From 1b670566f24efa86d2f824e21c4d4079533c856a Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:25:02 -0400 Subject: [PATCH 062/104] Removed bilby --- test/test_prior.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/test_prior.py b/test/test_prior.py index d9dbe27a..22ae827e 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -80,7 +80,13 @@ def test_uniform_sphere(self): assert jnp.all(jnp.isfinite(log_prob)) def test_power_law(self): - from bilby.core.prior.analytical import PowerLaw + def powerlaw_log_pdf(x, alpha, xmin, xmax): + if alpha == -1.0: + normalization = 1./(jnp.log(xmax) - jnp.log(xmin)) + else: + normalization = (1.0 + alpha) / (xmax**(1.0 + alpha) - xmin**(1.0 + alpha)) + return jnp.log(normalization) + alpha * jnp.log(x) + def func(alpha): xmin = 0.05 xmax = 10.0 @@ -102,11 +108,13 @@ def func(alpha): transformed_samples = transformed_samples[transformed_samples >= xmin] samples = samples[transformed_samples <= xmax] transformed_samples = transformed_samples[transformed_samples <= xmax] - assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), PowerLaw(alpha, xmin, xmax).ln_prob(transformed_samples), atol=1e-4) + # log pdf of powerlaw + assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) # Test Pareto Transform func(-1.0) # Test other values of alpha + print("Testing PowerLawPrior") positive_alpha = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0] for alpha_val in positive_alpha: func(alpha_val) From 9d59c4ddd75cd2583795c5403aeae1895833305d Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 17:00:48 -0400 Subject: [PATCH 063/104] Move unit test to unit directory And create integration directory --- docs/tutorials/naming_system.md | 0 test/integration/test_GW150914.py | 0 test/{ => unit}/test_detector.py | 0 test/{ => unit}/test_prior.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/tutorials/naming_system.md create mode 100644 test/integration/test_GW150914.py rename test/{ => unit}/test_detector.py (100%) rename test/{ => unit}/test_prior.py (100%) diff --git a/docs/tutorials/naming_system.md b/docs/tutorials/naming_system.md new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_detector.py b/test/unit/test_detector.py similarity index 100% rename from test/test_detector.py rename to test/unit/test_detector.py diff --git a/test/test_prior.py b/test/unit/test_prior.py similarity index 100% 
rename from test/test_prior.py rename to test/unit/test_prior.py From b3110dfd33d6df85cd69e1fa4b2face55979076c Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 17:27:17 -0400 Subject: [PATCH 064/104] Update priors naming issue add test_GW150914.py --- .github/workflows/python-package.yml | 6 +- src/jimgw/prior.py | 12 ++- test/integration/test_GW150914.py | 116 +++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 12 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 4806efed..1b45e23c 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -3,11 +3,7 @@ name: Python package -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] +on: push, pull_request jobs: build: diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3697df2f..f224b119 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -248,7 +248,7 @@ def __init__( self.xmax = xmax self.xmin = xmin super().__init__( - LogisticDistribution(self.parameter_names), + LogisticDistribution([f"{self.parameter_names[0]}_base"]), [ LogitTransform( ( @@ -258,12 +258,10 @@ def __init__( ), ScaleTransform( ( - ( - [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], - [f"{self.parameter_names[0]}-({xmin})"], - ), - xmax - xmin, + [f"({self.parameter_names[0]}-({xmin}))/{(xmax-xmin)}"], + [f"{self.parameter_names[0]}-({xmin})"], ), + xmax - xmin, ), OffsetTransform( ([f"{self.parameter_names[0]}-({xmin})"], self.parameter_names), @@ -383,7 +381,7 @@ def __init__( alpha, ) super().__init__( - LogisticDistribution(self.parameter_names), + LogisticDistribution([f"{self.parameter_names[0]}_base"]), [ LogitTransform( ( diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index e69de29b..816647fa 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -0,0 +1,116 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"] + +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior( + 0.125, + 1.0, + parameter_names=["q"], +) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +cos_iota_prior = UniformPrior( + -1.0, + 1.0, + parameter_names=["cos_iota"], +) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior 
= UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +sin_dec_prior = UniformPrior( + -1.0, + 1.0, + parameter_names=["sin_dec"], +) + +prior = CombinePrior( + [ + Mc_prior, + q_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] +) +likelihood = TransientLikelihoodFD( + [H1, L1], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) From 024c5c3fa6bb9dd69ea5f2dd573046a6d849baac Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 17:53:00 -0400 Subject: [PATCH 065/104] base parameter names seem working --- src/jimgw/jim.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 08b82b05..1a694172 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -7,7 +7,7 @@ from jaxtyping import Array, Float, PRNGKeyArray from jimgw.base import LikelihoodBase -from jimgw.prior import Prior +from jimgw.prior import Prior, trace_prior_parent class Jim(object): @@ -26,11 +26,18 @@ def __init__( self, likelihood: LikelihoodBase, prior: Prior, - parameter_names: list[str], + parameter_names: list[str] | None = None, **kwargs, ): self.likelihood = likelihood self.prior = prior + if parameter_names is None: + print("No parameter names provided. 
Will try to trace the prior.") + parents = [] + trace_prior_parent(prior, parents) + parameter_names = [] + for parent in parents: + parameter_names.extend(parent.parameter_names) self.parameter_names = parameter_names seed = kwargs.get("seed", 0) From ac43d54ffb97089513b71e269fcf8bde954d7839 Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 17:55:44 -0400 Subject: [PATCH 066/104] fix cosine naming error --- src/jimgw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index f224b119..6a8f10ca 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -284,7 +284,7 @@ def __init__(self, parameter_names: list[str]): self.parameter_names = parameter_names assert self.n_dim == 1, "SinePrior needs to be 1D distributions" super().__init__( - CosinePrior([f"{self.parameter_names[0]}"]), + CosinePrior([f"{self.parameter_names[0]}-pi/2"]), [ OffsetTransform( ( From 151c37df64d147bf5b7e5ddfffa9dd29f0e6d221 Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 17:59:20 -0400 Subject: [PATCH 067/104] prior test now all pass --- .github/workflows/python-package.yml | 2 +- test/unit/test_prior.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 1b45e23c..d70c2ae1 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -3,7 +3,7 @@ name: Python package -on: push, pull_request +on: [push, pull_request] jobs: build: diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index 22ae827e..20d71de4 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -96,20 +96,20 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x'] - base_log_p = jax.vmap(p.log_prob, [0])({'x':samples}) + samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x_base'] + base_log_p = jax.vmap(p.log_prob, [0])({'x_base':samples}) assert jnp.all(jnp.isfinite(base_log_p)) # Check that the log_prob is correct in the support samples = jnp.linspace(-10.0, 10.0, 1000) - transformed_samples = jax.vmap(p.transform)({'x': samples})['x'] + transformed_samples = jax.vmap(p.transform)({'x_base': samples})['x'] # cut off the samples that are outside the support samples = samples[transformed_samples >= xmin] transformed_samples = transformed_samples[transformed_samples >= xmin] samples = samples[transformed_samples <= xmax] transformed_samples = transformed_samples[transformed_samples <= xmax] # log pdf of powerlaw - assert jnp.allclose(jax.vmap(p.log_prob)({'x':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) + assert jnp.allclose(jax.vmap(p.log_prob)({'x_base':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4) # Test Pareto Transform func(-1.0) From 576ce7cd14cda25522c4dda9cfce906b2f4570f1 Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 29 Jul 2024 18:20:29 -0400 Subject: [PATCH 068/104] Fix likelihood sampling issue due to parameter name transformation --- test/integration/test_GW150914.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 816647fa..2650cd1c 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -36,24 +36,25 @@ q_prior = 
UniformPrior( 0.125, 1.0, - parameter_names=["q"], + parameter_names=["q"], # Need name transformation in likelihood to work ) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +# Current likelihood sampling will fail and give nan because of large number dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) cos_iota_prior = UniformPrior( -1.0, 1.0, - parameter_names=["cos_iota"], + parameter_names=["cos_iota"], # Need name transformation in likelihood to work ) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) sin_dec_prior = UniformPrior( -1.0, 1.0, - parameter_names=["sin_dec"], + parameter_names=["sin_dec"], # Need name transformation in likelihood to work ) prior = CombinePrior( From 7625dbe04eab2c436d2ae27c91d2ee3148c98b0d Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 30 Jul 2024 14:51:56 -0400 Subject: [PATCH 069/104] Instead of defining univariate and multivariate, favor bijective versus nonbijective --- src/jimgw/transforms.py | 87 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 9 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4c1380cf..902b5b30 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -10,12 +10,10 @@ class Transform(ABC): """ Base class for transform. - - The idea of transform should be used on distribtuion, + The purpose of this class is purely for keeping name """ name_mapping: tuple[list[str], list[str]] - transform_func: Callable[[dict[str, Float]], dict[str, Float]] def __init__( self, @@ -23,6 +21,17 @@ def __init__( ): self.name_mapping = name_mapping + def propagate_name(self, x: list[str]) -> list[str]: + input_set = set(x) + from_set = set(self.name_mapping[0]) + to_set = set(self.name_mapping[1]) + return list(input_set - from_set | to_set) + +class BijectiveTransform(Transform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] + def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: return self.transform(x) @@ -62,13 +71,64 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: The transformed dictionary. """ raise NotImplementedError + + @abstractmethod + def inverse(self, y: dict[str, Float]) -> dict[str, Float]: + """ + Inverse transform the input y to original coordinate x. - def propagate_name(self, x: list[str]) -> list[str]: - input_set = set(x) - from_set = set(self.name_mapping[0]) - to_set = set(self.name_mapping[1]) - return list(input_set - from_set | to_set) + Parameters + ---------- + y : dict[str, Float] + The transformed dictionary. + + Returns + ------- + x : dict[str, Float] + The original dictionary. + """ + raise NotImplementedError + + @abstractmethod + def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + """ + Pull back the input y to original coordinate x and return the log Jacobian determinant. + + Parameters + ---------- + y : dict[str, Float] + The transformed dictionary. + + Returns + ------- + x : dict[str, Float] + The original dictionary. + log_det : Float + The log Jacobian determinant. 
+ """ + raise NotImplementedError + +class NonBijectiveTransform(Transform): + + def __call__(self, x: dict[str, Float]) -> dict[str, Float]: + return self.forward(x) + + @abstractmethod + def forward(self, x: dict[str, Float]) -> dict[str, Float]: + """ + Push forward the input x to transformed coordinate y. + Parameters + ---------- + x : dict[str, Float] + The input dictionary. + + Returns + ------- + y : dict[str, Float] + The transformed dictionary. + """ + raise NotImplementedError class UnivariateTransform(Transform): @@ -94,7 +154,7 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: return x -class ScaleTransform(UnivariateTransform): +class ScaleTransform(BijectiveTransform): scale: Float def __init__( @@ -105,6 +165,15 @@ def __init__( super().__init__(name_mapping) self.scale = scale self.transform_func = lambda x: x * self.scale + self.inverse_transform_func = lambda x: x / self.scale + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + input_params = x.pop(self.name_mapping[0][0]) + assert_rank(input_params, 0) + output_params = self.transform_func(input_params) + jacobian = jnp.log(jnp.abs(self.scale)) + x[self.name_mapping[1][0]] = output_params + return x, jacobian class OffsetTransform(UnivariateTransform): From 9ac722f57dbfed388bd0ab2aa2244f9463a5cc85 Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 30 Jul 2024 16:29:54 -0400 Subject: [PATCH 070/104] Adding some transform, should be working --- src/jimgw/transforms.py | 188 ++++++++++++++++++---------------------- 1 file changed, 84 insertions(+), 104 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 902b5b30..904fe225 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp from chex import assert_rank -from jaxtyping import Float +from jaxtyping import Float, Array class Transform(ABC): @@ -29,13 +29,12 @@ def propagate_name(self, x: list[str]) -> list[str]: class BijectiveTransform(Transform): - transform_func: Callable[[dict[str, Float]], dict[str, Float]] - inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] + transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + inverse_transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: return self.transform(x) - @abstractmethod def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Transform the input x to transformed coordinate y and return the log Jacobian determinant. @@ -53,9 +52,12 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. """ - raise NotImplementedError + input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) + output_params = self.transform_func(input_params) + jacobian = jnp.array(jax.jacfwd(self.transform_func)(input_params)) + jax.tree.map(lambda key, value: x.update({key: value}), self.name_mapping[1], output_params) + return x, jnp.log(jnp.linalg.det(jacobian)) - @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ Push forward the input x to transformed coordinate y. @@ -70,9 +72,11 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: y : dict[str, Float] The transformed dictionary. 
""" - raise NotImplementedError + input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) + output_params = self.transform_func(input_params) + jax.tree.map(lambda key, value: x.update({key: value}), self.name_mapping[1], output_params) + return x - @abstractmethod def inverse(self, y: dict[str, Float]) -> dict[str, Float]: """ Inverse transform the input y to original coordinate x. @@ -87,9 +91,12 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: x : dict[str, Float] The original dictionary. """ - raise NotImplementedError + output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) + input_params = self.inverse_transform_func(output_params) + jacobian = jnp.array(jax.jacfwd(self.inverse_transform_func)(output_params)) + jax.tree.map(lambda key, value: y.update({key: value}), self.name_mapping[0], input_params) + return y, jnp.log(jnp.linalg.det(jacobian)) - @abstractmethod def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Pull back the input y to original coordinate x and return the log Jacobian determinant. @@ -106,7 +113,10 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. """ - raise NotImplementedError + output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) + input_params = self.inverse_transform_func(output_params) + jax.tree.map(lambda key, value: y.update({key: value}), self.name_mapping[0], input_params) + return y class NonBijectiveTransform(Transform): @@ -130,30 +140,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ raise NotImplementedError -class UnivariateTransform(Transform): - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - - def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - input_params = x.pop(self.name_mapping[0][0]) - assert_rank(input_params, 0) - output_params = self.transform_func(input_params) - jacobian = jax.jacfwd(self.transform_func)(input_params) - x[self.name_mapping[1][0]] = output_params - return x, jnp.log(jacobian) - - def forward(self, x: dict[str, Float]) -> dict[str, Float]: - input_params = x.pop(self.name_mapping[0][0]) - assert_rank(input_params, 0) - output_params = self.transform_func(input_params) - x[self.name_mapping[1][0]] = output_params - return x - - class ScaleTransform(BijectiveTransform): scale: Float @@ -164,19 +150,10 @@ def __init__( ): super().__init__(name_mapping) self.scale = scale - self.transform_func = lambda x: x * self.scale - self.inverse_transform_func = lambda x: x / self.scale - - def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - input_params = x.pop(self.name_mapping[0][0]) - assert_rank(input_params, 0) - output_params = self.transform_func(input_params) - jacobian = jnp.log(jnp.abs(self.scale)) - x[self.name_mapping[1][0]] = output_params - return x, jacobian - + self.transform_func = lambda x: [x[0] * self.scale] + self.inverse_transform_func = lambda x: [x[0] / self.scale] -class OffsetTransform(UnivariateTransform): +class OffsetTransform(BijectiveTransform): offset: Float def __init__( @@ -186,10 +163,11 @@ def __init__( ): super().__init__(name_mapping) self.offset = offset - self.transform_func = lambda x: x + self.offset + self.transform_func = lambda x: [x[0] + self.offset] + self.inverse_transform_func = lambda x: [x[0] - self.offset] -class LogitTransform(UnivariateTransform): +class 
LogitTransform(BijectiveTransform): """ Logit transform following @@ -205,10 +183,11 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + self.transform_func = lambda x: [1 / (1 + jnp.exp(-x[0]))] + self.inverse_transform_func = lambda x: [jnp.log(x[0] / (1 - x[0]))] -class ArcSineTransform(UnivariateTransform): +class ArcSineTransform(BijectiveTransform): """ ArcSine transformation @@ -225,56 +204,57 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: jnp.arcsin(x) - - -class PowerLawTransform(UnivariateTransform): - """ - PowerLaw transformation - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ - - xmin: Float - xmax: Float - alpha: Float - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - alpha: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.alpha = alpha - self.transform_func = lambda x: ( - self.xmin ** (1.0 + self.alpha) - + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) - ) ** (1.0 / (1.0 + self.alpha)) - - -class ParetoTransform(UnivariateTransform): - """ - Pareto transformation: Power law when alpha = -1 - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - xmin: Float, - xmax: Float, - ): - super().__init__(name_mapping) - self.xmin = xmin - self.xmax = xmax - self.transform_func = lambda x: self.xmin * jnp.exp( - x * jnp.log(self.xmax / self.xmin) - ) + self.inverse_transform_func = lambda x: jnp.sin(x) + + +# class PowerLawTransform(UnivariateTransform): +# """ +# PowerLaw transformation +# Parameters +# ---------- +# name_mapping : tuple[list[str], list[str]] +# The name mapping between the input and output dictionary. +# """ + +# xmin: Float +# xmax: Float +# alpha: Float + +# def __init__( +# self, +# name_mapping: tuple[list[str], list[str]], +# xmin: Float, +# xmax: Float, +# alpha: Float, +# ): +# super().__init__(name_mapping) +# self.xmin = xmin +# self.xmax = xmax +# self.alpha = alpha +# self.transform_func = lambda x: ( +# self.xmin ** (1.0 + self.alpha) +# + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha)) +# ) ** (1.0 / (1.0 + self.alpha)) + + +# class ParetoTransform(UnivariateTransform): +# """ +# Pareto transformation: Power law when alpha = -1 +# Parameters +# ---------- +# name_mapping : tuple[list[str], list[str]] +# The name mapping between the input and output dictionary. 
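A quick check of the logistic map used by LogitTransform above: since the derivative of the sigmoid is sigma(x) * (1 - sigma(x)), the forward log-Jacobian is log sigma(x) + log(1 - sigma(x)). A standalone sketch (value of x0 is arbitrary):

    import jax
    import jax.numpy as jnp

    sigmoid = lambda x: 1 / (1 + jnp.exp(-x))
    x0 = 0.3
    s = sigmoid(x0)
    log_jac = jnp.log(jnp.abs(jax.jacfwd(sigmoid)(x0)))
    assert jnp.isclose(log_jac, jnp.log(s) + jnp.log(1 - s))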
+# """ + +# def __init__( +# self, +# name_mapping: tuple[list[str], list[str]], +# xmin: Float, +# xmax: Float, +# ): +# super().__init__(name_mapping) +# self.xmin = xmin +# self.xmax = xmax +# self.transform_func = lambda x: self.xmin * jnp.exp( +# x * jnp.log(self.xmax / self.xmin) +# ) From 935b31484e06ddbfa3b985e22a8303c0dc50b7a3 Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 30 Jul 2024 16:47:20 -0400 Subject: [PATCH 071/104] Refactor into NtoN and NtoM transform --- src/jimgw/prior.py | 5 ++-- src/jimgw/transforms.py | 52 ++++++++++++++++++++++++++++++----------- 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 6a8f10ca..22985a8a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -8,6 +8,7 @@ from jimgw.transforms import ( Transform, + NtoNTransform, LogitTransform, ScaleTransform, OffsetTransform, @@ -149,7 +150,7 @@ class SequentialTransformPrior(Prior): """ base_prior: Prior - transforms: list[Transform] + transforms: list[NtoNTransform] def __repr__(self): return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" @@ -157,7 +158,7 @@ def __repr__(self): def __init__( self, base_prior: Prior, - transforms: list[Transform], + transforms: list[NtoNTransform], ): self.base_prior = base_prior diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 904fe225..4e62316e 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -27,13 +27,10 @@ def propagate_name(self, x: list[str]) -> list[str]: to_set = set(self.name_mapping[1]) return list(input_set - from_set | to_set) -class BijectiveTransform(Transform): - transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] - inverse_transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] +class NtoNTransform(Transform): - def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - return self.transform(x) + transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ @@ -55,7 +52,11 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) output_params = self.transform_func(input_params) jacobian = jnp.array(jax.jacfwd(self.transform_func)(input_params)) - jax.tree.map(lambda key, value: x.update({key: value}), self.name_mapping[1], output_params) + jax.tree.map( + lambda key, value: x.update({key: value}), + self.name_mapping[1], + output_params, + ) return x, jnp.log(jnp.linalg.det(jacobian)) def forward(self, x: dict[str, Float]) -> dict[str, Float]: @@ -74,9 +75,21 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) output_params = self.transform_func(input_params) - jax.tree.map(lambda key, value: x.update({key: value}), self.name_mapping[1], output_params) + jax.tree.map( + lambda key, value: x.update({key: value}), + self.name_mapping[1], + output_params, + ) return x - + + +class BijectiveTransform(NtoNTransform): + + inverse_transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + + def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + return self.transform(x) + def inverse(self, y: dict[str, Float]) -> dict[str, Float]: """ Inverse transform the input y to original coordinate x. 
@@ -94,7 +107,11 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) input_params = self.inverse_transform_func(output_params) jacobian = jnp.array(jax.jacfwd(self.inverse_transform_func)(output_params)) - jax.tree.map(lambda key, value: y.update({key: value}), self.name_mapping[0], input_params) + jax.tree.map( + lambda key, value: y.update({key: value}), + self.name_mapping[0], + input_params, + ) return y, jnp.log(jnp.linalg.det(jacobian)) def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: @@ -115,14 +132,21 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) input_params = self.inverse_transform_func(output_params) - jax.tree.map(lambda key, value: y.update({key: value}), self.name_mapping[0], input_params) + jax.tree.map( + lambda key, value: y.update({key: value}), + self.name_mapping[0], + input_params, + ) return y -class NonBijectiveTransform(Transform): - + +class NtoMTransform(Transform): + + transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " m_dim"]] + def __call__(self, x: dict[str, Float]) -> dict[str, Float]: return self.forward(x) - + @abstractmethod def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -140,6 +164,7 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ raise NotImplementedError + class ScaleTransform(BijectiveTransform): scale: Float @@ -153,6 +178,7 @@ def __init__( self.transform_func = lambda x: [x[0] * self.scale] self.inverse_transform_func = lambda x: [x[0] / self.scale] + class OffsetTransform(BijectiveTransform): offset: Float From 685670a3c97b3b0300ce6a7ef6351bd56375c3e1 Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 30 Jul 2024 17:06:02 -0400 Subject: [PATCH 072/104] Fix bugs in ArcSine --- src/jimgw/prior.py | 4 ++-- src/jimgw/transforms.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 22985a8a..f291b23a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -13,8 +13,8 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - PowerLawTransform, - ParetoTransform, + # PowerLawTransform, + # ParetoTransform, ) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4e62316e..ca759a07 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -229,8 +229,8 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: jnp.arcsin(x) - self.inverse_transform_func = lambda x: jnp.sin(x) + self.transform_func = lambda x: [jnp.arcsin(x[0])] + self.inverse_transform_func = lambda x: [jnp.sin(x[0])] # class PowerLawTransform(UnivariateTransform): From 39126f5cf51be98d7ac5d53ed3ba3ed4e37c5259 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Tue, 30 Jul 2024 17:41:48 -0400 Subject: [PATCH 073/104] Add inverse mass transform --- src/jimgw/single_event/utils.py | 54 +++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 3f0decce..870c70ae 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -61,7 +61,7 @@ def m1m2_to_Mq(m1: Float, m2: Float): return M_tot, q -def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): +def M_q_to_m1_m2(trans_M_tot: Float, trans_q: Float): """ Transforming the Total mass M and 
mass ratio q to the primary mass m1 and secondary mass m2. @@ -87,7 +87,7 @@ def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): return m1, m2 -def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: +def Mc_q_to_m1_m2(Mc: Float, q: Float) -> tuple[Float, Float]: """ Transforming the chirp mass Mc and mass ratio q to the primary mass m1 and secondary mass m2. @@ -113,6 +113,56 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: return m1, m2 +def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the chirp mass Mc + and mass ratio q. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + Mc : Float + Chirp mass. + q : Float + Mass ratio. + """ + M_tot = m1 + m2 + eta = m1 * m2 / M_tot ** 2 + Mc = M_tot * eta ** (3.0 / 5) + q = m2 / m1 + return Mc, q + + +def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the total mass M + and symmetric mass ratio eta. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M : Float + Total mass. + eta : Float + Symmetric mass ratio. + """ + M = m1 + m2 + eta = m1 * m2 / M ** 2 + return M, eta + + def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: """ Transforming the right ascension ra and declination dec to the polar angle From 46bd0442c6552a8967a0f9be388fa778a385bd4c Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 30 Jul 2024 23:27:24 -0400 Subject: [PATCH 074/104] Update sequential prior class and test_GW15014.py. --- src/jimgw/prior.py | 31 +++++++++++++++++-------------- test/integration/test_GW150914.py | 26 +++++++++----------------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index f291b23a..e23d0dfa 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -7,8 +7,7 @@ from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped from jimgw.transforms import ( - Transform, - NtoNTransform, + BijectiveTransform, LogitTransform, ScaleTransform, OffsetTransform, @@ -61,7 +60,7 @@ def sample( ) -> dict[str, Float[Array, " n_samples"]]: raise NotImplementedError - def log_prob(self, x: dict[str, Array]) -> Float: + def log_prob(self, z: dict[str, Array]) -> Float: raise NotImplementedError @@ -99,7 +98,7 @@ def sample( samples = jnp.log(samples / (1 - samples)) return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Float]) -> Float: + def log_prob(self, z: dict[str, Float]) -> Float: variable = x[self.parameter_names[0]] return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) @@ -139,7 +138,7 @@ def sample( samples = jax.random.normal(rng_key, (n_samples,)) return self.add_name(samples[None]) - def log_prob(self, x: dict[str, Float]) -> Float: + def log_prob(self, z: dict[str, Float]) -> Float: variable = x[self.parameter_names[0]] return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) @@ -147,10 +146,12 @@ def log_prob(self, x: dict[str, Float]) -> Float: class SequentialTransformPrior(Prior): """ Transform a prior distribution by applying a sequence of transforms. 
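The forward/inverse mass maps added in PATCH 073 can be checked numerically; reproducing the two functions in isolation (test values are arbitrary), the round trip recovers the inputs:

    import jax.numpy as jnp

    def Mc_q_to_m1_m2(M_c, q):
        eta = q / (1 + q) ** 2
        M_tot = M_c / eta ** (3.0 / 5)
        m1 = M_tot / (1 + q)
        return m1, m1 * q

    def m1_m2_to_Mc_q(m1, m2):
        M_tot = m1 + m2
        eta = m1 * m2 / M_tot**2
        return M_tot * eta ** (3.0 / 5), m2 / m1

    m1, m2 = Mc_q_to_m1_m2(30.0, 0.8)
    M_c, q = m1_m2_to_Mc_q(m1, m2)
    assert jnp.isclose(M_c, 30.0) and jnp.isclose(q, 0.8)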
+ The space before the transform is named as x, + and the space after the transform is named as z """ base_prior: Prior - transforms: list[NtoNTransform] + transforms: list[BijectiveTransform] def __repr__(self): return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})" @@ -158,7 +159,7 @@ def __repr__(self): def __init__( self, base_prior: Prior, - transforms: list[NtoNTransform], + transforms: list[BijectiveTransform], ): self.base_prior = base_prior @@ -174,14 +175,16 @@ def sample( output = self.base_prior.sample(rng_key, n_samples) return jax.vmap(self.transform)(output) - def log_prob(self, x: dict[str, Float]) -> Float: + def log_prob(self, z: dict[str, Float]) -> Float: """ - log_prob has to be evaluated in the space of the base_prior. + Evaluating the probability of the transformed variable z. + This is what flowMC should sample from """ - output = self.base_prior.log_prob(x) - for transform in self.transforms: - x, log_jacobian = transform.transform(x) + output = 0 + for transform in reversed(self.transforms): + z, log_jacobian = transform.inverse(z) output -= log_jacobian + output += self.base_prior.log_prob(z) return output def transform(self, x: dict[str, Float]) -> dict[str, Float]: @@ -223,10 +226,10 @@ def sample( output.update(prior.sample(subkey, n_samples)) return output - def log_prob(self, x: dict[str, Float]) -> Float: + def log_prob(self, z: dict[str, Float]) -> Float: output = 0.0 for prior in self.base_prior: - output += prior.log_prob(x) + output += prior.log_prob(z) return output diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 2650cd1c..deb3fb98 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD @@ -33,10 +33,10 @@ L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior( +eta_prior = UniformPrior( 0.125, - 1.0, - parameter_names=["q"], # Need name transformation in likelihood to work + 0.25, + parameter_names=["eta"], # Need name transformation in likelihood to work ) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) @@ -44,32 +44,24 @@ dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -cos_iota_prior = UniformPrior( - -1.0, - 1.0, - parameter_names=["cos_iota"], # Need name transformation in likelihood to work -) +iota_prior = CosinePrior(parameter_names=["iota"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) -sin_dec_prior = UniformPrior( - -1.0, - 1.0, - parameter_names=["sin_dec"], # Need name transformation in likelihood to work -) +dec_prior = SinePrior(parameter_names=["dec"]) prior = CombinePrior( [ Mc_prior, - q_prior, + eta_prior, s1z_prior, s2z_prior, dL_prior, t_c_prior, phase_c_prior, - cos_iota_prior, + iota_prior, psi_prior, ra_prior, - sin_dec_prior, + dec_prior, ] ) likelihood = 
TransientLikelihoodFD( From 8e1441b389560bcb07f6a137705aba6db384d074 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 08:52:16 -0400 Subject: [PATCH 075/104] correct sign errors and minior bugs --- src/jimgw/prior.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e23d0dfa..13ab622f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -99,7 +99,7 @@ def sample( return self.add_name(samples[None]) def log_prob(self, z: dict[str, Float]) -> Float: - variable = x[self.parameter_names[0]] + variable = z[self.parameter_names[0]] return -variable - 2 * jnp.log(1 + jnp.exp(-variable)) @@ -139,7 +139,7 @@ def sample( return self.add_name(samples[None]) def log_prob(self, z: dict[str, Float]) -> Float: - variable = x[self.parameter_names[0]] + variable = z[self.parameter_names[0]] return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) @@ -183,7 +183,7 @@ def log_prob(self, z: dict[str, Float]) -> Float: output = 0 for transform in reversed(self.transforms): z, log_jacobian = transform.inverse(z) - output -= log_jacobian + output += log_jacobian output += self.base_prior.log_prob(z) return output From dfdfffad0bc6847ad1519e7884426292b691ba80 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 09:43:00 -0400 Subject: [PATCH 076/104] Add mass transform --- src/jimgw/single_event/utils.py | 58 ++++++++++++++++----------------- src/jimgw/transforms.py | 44 ++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 30 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 870c70ae..4ee5c25e 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -87,14 +87,14 @@ def M_q_to_m1_m2(trans_M_tot: Float, trans_q: Float): return m1, m2 -def Mc_q_to_m1_m2(Mc: Float, q: Float) -> tuple[Float, Float]: +def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]: """ - Transforming the chirp mass Mc and mass ratio q to the primary mass m1 and + Transforming the chirp mass M_c and mass ratio q to the primary mass m1 and secondary mass m2. Parameters ---------- - Mc : Float + M_c : Float Chirp mass. q : Float Mass ratio. @@ -107,36 +107,36 @@ def Mc_q_to_m1_m2(Mc: Float, q: Float) -> tuple[Float, Float]: Secondary mass. """ eta = q / (1 + q) ** 2 - M_tot = Mc / eta ** (3.0 / 5) + M_tot = M_c / eta ** (3.0 / 5) m1 = M_tot / (1 + q) m2 = m1 * q return m1, m2 -def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: - """ - Transforming the primary mass m1 and secondary mass m2 to the chirp mass Mc - and mass ratio q. - - Parameters - ---------- - m1 : Float - Primary mass. - m2 : Float - Secondary mass. - - Returns - ------- - Mc : Float - Chirp mass. - q : Float - Mass ratio. - """ - M_tot = m1 + m2 - eta = m1 * m2 / M_tot ** 2 - Mc = M_tot * eta ** (3.0 / 5) - q = m2 / m1 - return Mc, q +def m1_m2_to_M_c_q(m1: Float, m2: Float) -> tuple[Float, Float]: + """ + Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c + and mass ratio q. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + """ + M_tot = m1 + m2 + eta = m1 * m2 / M_tot**2 + M_c = M_tot * eta ** (3.0 / 5) + q = m2 / m1 + return M_c, q def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: @@ -159,7 +159,7 @@ def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: Symmetric mass ratio. 
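The sign fix in PATCH 075 is the standard change of variables: for z = f(x), log p_z(z) = log p_x(x) + log|dx/dz| with x = f^-1(z), so the Jacobian term is added, not subtracted. A 1-D sketch with a hypothetical scale map and a standard-normal base density:

    import jax
    import jax.numpy as jnp

    log_prob_x = lambda x: -0.5 * x**2 - 0.5 * jnp.log(2 * jnp.pi)  # N(0, 1)
    f_inv = lambda z: z / 2.0  # inverse of the transform z = 2x

    def log_prob_z(z):
        x = f_inv(z)
        log_jac = jnp.log(jnp.abs(jax.jacfwd(f_inv)(z)))
        return log_prob_x(x) + log_jac  # '+=', as in the corrected code

    # z = 2x with x ~ N(0, 1) means z ~ N(0, 4); compare directly:
    z0 = 1.3
    expected = -0.5 * (z0 / 2.0) ** 2 - 0.5 * jnp.log(2 * jnp.pi * 4.0)
    assert jnp.isclose(log_prob_z(z0), expected)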
""" M = m1 + m2 - eta = m1 * m2 / M ** 2 + eta = m1 * m2 / M**2 return M, eta diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index ca759a07..e377afd2 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -3,9 +3,10 @@ import jax import jax.numpy as jnp -from chex import assert_rank from jaxtyping import Float, Array +from jimgw.single_event.utils import m1_m2_to_Mc_q, Mc_q_to_m1_m2 + class Transform(ABC): """ @@ -233,6 +234,47 @@ def __init__( self.inverse_transform_func = lambda x: [jnp.sin(x[0])] +class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): + """ + Transform component masses to chirp mass and mass ratio. + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + assert name_mapping == (["m_1", "m_2"], ["M_c", "q"]) + super().__init__(name_mapping) + self.transform_func = lambda x: m1_m2_to_Mc_q(x[0], x[1]) + self.inverse_transform_func = lambda x: Mc_q_to_m1_m2(x[0], x[1]) + + +def inverse(transform: BijectiveTransform) -> BijectiveTransform: + """ + Inverse the transform. + + Parameters + ---------- + transform : BijectiveTransform + The transform to be inverted. + + Returns + ------- + BijectiveTransform + The inverted transform. + """ + return BijectiveTransform( + name_mapping=transform.name_mapping, + transform_func=transform.inverse_transform_func, + inverse_transform_func=transform.transform_func, + ) + + # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 9d87e58fa19fc4169ab8f4184e3c1738e817aa6f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 11:08:53 -0400 Subject: [PATCH 077/104] Add simplex transform --- src/jimgw/single_event/utils.py | 4 ++-- src/jimgw/transforms.py | 31 ++++++++++++++++--------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 4ee5c25e..3d07bb57 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -37,7 +37,7 @@ def inner_product( return 4.0 * jnp.real(trapezoid(integrand, dx=df)) -def m1m2_to_Mq(m1: Float, m2: Float): +def m1_m2_to_M_q(m1: Float, m2: Float): """ Transforming the primary mass m1 and secondary mass m2 to the Total mass M and mass ratio q. @@ -113,7 +113,7 @@ def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]: return m1, m2 -def m1_m2_to_M_c_q(m1: Float, m2: Float) -> tuple[Float, Float]: +def m1_m2_to_Mc_q(m1: Float, m2: Float) -> tuple[Float, Float]: """ Transforming the primary mass m1 and secondary mass m2 to the chirp mass M_c and mass ratio q. diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index e377afd2..be181fcd 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -196,7 +196,7 @@ def __init__( class LogitTransform(BijectiveTransform): """ - Logit transform following + Logit transformation Parameters ---------- @@ -254,25 +254,26 @@ def __init__( self.inverse_transform_func = lambda x: Mc_q_to_m1_m2(x[0], x[1]) -def inverse(transform: BijectiveTransform) -> BijectiveTransform: +class RectangleToTriangleTransform(BijectiveTransform): """ - Inverse the transform. + Transform a rectangle grid with bounds [0, 1] x [0, 1] to a triangle grid with vertices (0, 0), (1, 0), (0, 1), while preserving the area density. Parameters ---------- - transform : BijectiveTransform - The transform to be inverted. 
- - Returns - ------- - BijectiveTransform - The inverted transform. + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. """ - return BijectiveTransform( - name_mapping=transform.name_mapping, - transform_func=transform.inverse_transform_func, - inverse_transform_func=transform.transform_func, - ) + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: [ + 1 - jnp.sqrt(1 - x[0]), + x[1] * jnp.sqrt(1 - x[0]), + ] + self.inverse_transform_func = lambda x: [1 - (1 - x[0]) ** 2, x[1] / (1 - x[0])] # class PowerLawTransform(UnivariateTransform): From 52fa0eb28308aae155c566c677317b94645a83ea Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 11:10:23 -0400 Subject: [PATCH 078/104] Change transform to take dictionary as input for transform_func --- src/jimgw/transforms.py | 119 +++++++++++++++++++++++++++------------- 1 file changed, 80 insertions(+), 39 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index ca759a07..13d66173 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -30,7 +30,7 @@ def propagate_name(self, x: list[str]) -> list[str]: class NtoNTransform(Transform): - transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + transform_func: Callable[[dict[str, Float]], dict[str, Float]] def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ @@ -49,15 +49,22 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. """ - input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) - output_params = self.transform_func(input_params) - jacobian = jnp.array(jax.jacfwd(self.transform_func)(input_params)) + x_copy = x.copy() + output_params = self.transform_func(x_copy) + jacobian = jax.jacfwd(self.transform_func)(x_copy) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log( + jnp.linalg.det(jacobian.reshape(int(jnp.sqrt(jacobian.size)), -1)) + ) jax.tree.map( - lambda key, value: x.update({key: value}), - self.name_mapping[1], - output_params, + lambda key: x_copy.pop(key), + self.name_mapping[0], ) - return x, jnp.log(jnp.linalg.det(jacobian)) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy, jacobian def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ @@ -73,19 +80,22 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: y : dict[str, Float] The transformed dictionary. 
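The rectangle-to-triangle map of PATCH 077 has constant Jacobian determinant 1/2, so a uniform density on the unit square stays uniform on the triangle with vertices (0, 0), (1, 0), (0, 1). Both claims are easy to verify numerically in a standalone sketch:

    import jax
    import jax.numpy as jnp

    def rect_to_tri(uv):
        u, v = uv
        return jnp.array([1 - jnp.sqrt(1 - u), v * jnp.sqrt(1 - u)])

    # Every image lands inside the triangle a + b <= 1 ...
    uv = jax.random.uniform(jax.random.PRNGKey(0), (1000, 2))
    ab = jax.vmap(rect_to_tri)(uv)
    assert jnp.all(ab.sum(axis=1) <= 1.0 + 1e-6)

    # ... and the Jacobian determinant is 1/2 everywhere.
    det = jnp.linalg.det(jax.jacfwd(rect_to_tri)(jnp.array([0.3, 0.6])))
    assert jnp.isclose(det, 0.5)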
""" - input_params = jax.tree.map(lambda key: x.pop(key), self.name_mapping[0]) - output_params = self.transform_func(input_params) + x_copy = x.copy() + output_params = self.transform_func(x_copy) jax.tree.map( - lambda key, value: x.update({key: value}), - self.name_mapping[1], - output_params, + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), ) - return x + return x_copy class BijectiveTransform(NtoNTransform): - inverse_transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]] + inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: return self.transform(x) @@ -104,15 +114,22 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: x : dict[str, Float] The original dictionary. """ - output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) - input_params = self.inverse_transform_func(output_params) - jacobian = jnp.array(jax.jacfwd(self.inverse_transform_func)(output_params)) + y_copy = y.copy() + output_params = self.inverse_transform_func(y_copy) + jacobian = jax.jacfwd(self.inverse_transform_func)(y_copy) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log( + jnp.linalg.det(jacobian.reshape(int(jnp.sqrt(jacobian.size)), -1)) + ) jax.tree.map( - lambda key, value: y.update({key: value}), - self.name_mapping[0], - input_params, + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), ) - return y, jnp.log(jnp.linalg.det(jacobian)) + return y_copy, jacobian def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ @@ -130,14 +147,17 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: log_det : Float The log Jacobian determinant. 
""" - output_params = jax.tree.map(lambda key: y.pop(key), self.name_mapping[1]) - input_params = self.inverse_transform_func(output_params) + y_copy = y.copy() + output_params = self.inverse_transform_func(y_copy) jax.tree.map( - lambda key, value: y.update({key: value}), - self.name_mapping[0], - input_params, + lambda key: y_copy.pop(key), + self.name_mapping[1], ) - return y + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy class NtoMTransform(Transform): @@ -175,8 +195,14 @@ def __init__( ): super().__init__(name_mapping) self.scale = scale - self.transform_func = lambda x: [x[0] * self.scale] - self.inverse_transform_func = lambda x: [x[0] / self.scale] + self.transform_func = lambda x: { + name_mapping[1][i]: x[name_mapping[0][i]] * self.scale + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: x[name_mapping[1][i]] / self.scale + for i in range(len(name_mapping[1])) + } class OffsetTransform(BijectiveTransform): @@ -189,9 +215,14 @@ def __init__( ): super().__init__(name_mapping) self.offset = offset - self.transform_func = lambda x: [x[0] + self.offset] - self.inverse_transform_func = lambda x: [x[0] - self.offset] - + self.transform_func = lambda x: { + name_mapping[1][i]: x[name_mapping[0][i]] + self.offset + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: x[name_mapping[1][i]] - self.offset + for i in range(len(name_mapping[1])) + } class LogitTransform(BijectiveTransform): """ @@ -209,9 +240,14 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: [1 / (1 + jnp.exp(-x[0]))] - self.inverse_transform_func = lambda x: [jnp.log(x[0] / (1 - x[0]))] - + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.log(x[name_mapping[0][i]] / (1 - x[name_mapping[0][i]])) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: 1 / (1 + jnp.exp(-x[name_mapping[1][i]])) + for i in range(len(name_mapping[1])) + } class ArcSineTransform(BijectiveTransform): """ @@ -229,9 +265,14 @@ def __init__( name_mapping: tuple[list[str], list[str]], ): super().__init__(name_mapping) - self.transform_func = lambda x: [jnp.arcsin(x[0])] - self.inverse_transform_func = lambda x: [jnp.sin(x[0])] - + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.arcsin(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.sin(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } # class PowerLawTransform(UnivariateTransform): # """ From 1e389bb1e19332576ece2d2911144a84e9dc16af Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 11:46:05 -0400 Subject: [PATCH 079/104] Add UniformComponentMassPrior --- src/jimgw/prior.py | 51 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 13ab622f..7880648d 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -12,8 +12,7 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - # PowerLawTransform, - # ParetoTransform, + RectangleToTriangleTransform, ) @@ -399,7 +398,53 @@ def __init__( @jaxtyped(typechecker=typechecker) -class UniformInComponentsChirpMassPrior(PowerLawPrior): +class UniformComponentMassPrior(SequentialTransformPrior): + """ + A prior in 
the range [xmin, xmax) for component masses which assumes the + component masses to be uniformly distributed. + """ + + xmin: float + xmax: float + + def __repr__(self): + return f"UniformComponentMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" + + def __init__(self, xmin: float, xmax: float, parameter_names: list[str]): + self.parameter_names = parameter_names + assert self.n_dim == 2, "UniformComponentMassPrior needs to be 2D distributions" + self.xmax = xmax + self.xmin = xmin + super().__init__( + CombinePrior( + [ + UniformPrior(xmin, xmax, ["x_1"]), + UniformPrior(xmin, xmax, ["x_2"]), + ] + ), + [ + ScaleTransform( + ( + ["x_1", "x_2"], + [f"x_1/({xmax-xmin})", f"x_2/({xmax-xmin})"], + ), + 1 / (xmax - xmin), + ), + RectangleToTriangleTransform( + ( + [ + f"{self.parameter_names[0]}/({xmax-xmin})", + f"{self.parameter_names[1]}/({xmax-xmin})", + ], + [self.parameter_names[0], self.parameter_names[1]], + ) + ), + ], + ) + + +@jaxtyped(typechecker=typechecker) +class UniformComponentChirpMassPrior(PowerLawPrior): """ A prior in the range [xmin, xmax) for chirp mass which assumes the component masses to be uniformly distributed. From 0252392aacf4fa53e18a7d5169e57dac6bcd15bb Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 12:26:25 -0400 Subject: [PATCH 080/104] prior system should work with dictionary function now --- src/jimgw/jim.py | 10 +++------- test/integration/test_GW150914.py | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 1a694172..81d56836 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -32,12 +32,8 @@ def __init__( self.likelihood = likelihood self.prior = prior if parameter_names is None: - print("No parameter names provided. Will try to trace the prior.") - parents = [] - trace_prior_parent(prior, parents) - parameter_names = [] - for parent in parents: - parameter_names.extend(parent.parameter_names) + print("No parameter names provided. 
Using prior names.") + parameter_names = prior.parameter_names self.parameter_names = parameter_names seed = kwargs.get("seed", 0) @@ -79,7 +75,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) - prior = self.prior.log_prob(named_params) + prior = self.prior.log_prob(named_params) return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index deb3fb98..4028164f 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -106,4 +106,4 @@ strategies=[Adam_optimizer, "default"], ) -jim.sample(jax.random.PRNGKey(42)) +# jim.sample(jax.random.PRNGKey(42)) From cb75bb9878036038e0dc8733bf51bddaa5a57f8a Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 12:31:03 -0400 Subject: [PATCH 081/104] update transform --- src/jimgw/transforms.py | 22 ++++++++++++++-------- test/integration/test_GW150914.py | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 13d66173..bfabc2ad 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -32,6 +32,10 @@ class NtoNTransform(Transform): transform_func: Callable[[dict[str, Float]], dict[str, Float]] + @property + def n_dim(self) -> int: + return len(self.name_mapping[0]) + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ Transform the input x to transformed coordinate y and return the log Jacobian determinant. @@ -50,11 +54,12 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: The log Jacobian determinant. """ x_copy = x.copy() - output_params = self.transform_func(x_copy) - jacobian = jax.jacfwd(self.transform_func)(x_copy) + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(int(jnp.sqrt(jacobian.size)), -1)) + jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) ) jax.tree.map( lambda key: x_copy.pop(key), @@ -115,11 +120,12 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: The original dictionary. 
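With the prior evaluated on a named dictionary as of PATCH 080, the posterior reduces to: zip the sampler's flat position onto the parameter names, then add log-prior and log-likelihood. A toy sketch of that flow (stub densities, not the real classes):

    import jax.numpy as jnp

    parameter_names = ["M_c", "q"]  # order fixed by the sampler
    add_name = lambda x: dict(zip(parameter_names, x))

    log_prior = lambda p: jnp.where((p["q"] > 0) & (p["q"] <= 1), 0.0, -jnp.inf)
    log_like = lambda p: -0.5 * ((p["M_c"] - 30.0) / 5.0) ** 2  # toy Gaussian

    def posterior(params):
        named = add_name(params)
        return log_like(named) + log_prior(named)

    print(posterior(jnp.array([28.0, 0.8])))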
""" y_copy = y.copy() - output_params = self.inverse_transform_func(y_copy) - jacobian = jax.jacfwd(self.inverse_transform_func)(y_copy) + transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) + output_params = self.inverse_transform_func(transform_params) + jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(int(jnp.sqrt(jacobian.size)), -1)) + jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) ) jax.tree.map( lambda key: y_copy.pop(key), @@ -241,11 +247,11 @@ def __init__( ): super().__init__(name_mapping) self.transform_func = lambda x: { - name_mapping[1][i]: jnp.log(x[name_mapping[0][i]] / (1 - x[name_mapping[0][i]])) + name_mapping[1][i]: 1 / (1 + jnp.exp(-x[name_mapping[0][i]])) for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: 1 / (1 + jnp.exp(-x[name_mapping[1][i]])) + name_mapping[0][i]: jnp.log(x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]])) for i in range(len(name_mapping[1])) } diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 4028164f..deb3fb98 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -106,4 +106,4 @@ strategies=[Adam_optimizer, "default"], ) -# jim.sample(jax.random.PRNGKey(42)) +jim.sample(jax.random.PRNGKey(42)) From 561e628785e290ed4d07f49a1b3dd7010140457c Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 13:16:52 -0400 Subject: [PATCH 082/104] Move subclassing structure --- src/jimgw/jim.py | 29 ++++++++++---- src/jimgw/transforms.py | 86 +++++++++++++++-------------------------- 2 files changed, 54 insertions(+), 61 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 81d56836..08878293 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -8,6 +8,7 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior, trace_prior_parent +from jimgw.transforms import BijectiveTransform, NtoMTransform class Jim(object): @@ -19,22 +20,30 @@ class Jim(object): prior: Prior # Name of parameters to sample from - parameter_names: list[str] + sample_transforms: list[BijectiveTransform] + likelihood_transforms: list[NtoMTransform] sampler: Sampler def __init__( self, likelihood: LikelihoodBase, prior: Prior, - parameter_names: list[str] | None = None, + sample_transforms: list[BijectiveTransform] = [], + likelihood_transforms: list[NtoMTransform] = [], **kwargs, ): self.likelihood = likelihood self.prior = prior - if parameter_names is None: - print("No parameter names provided. Using prior names.") - parameter_names = prior.parameter_names - self.parameter_names = parameter_names + + self.sample_transforms = sample_transforms + self.likelihood_transforms = likelihood_transforms + + if len(sample_transforms) == 0: + print("No sample transforms provided. Using prior parameters as sampling parameters") + + if len(likelihood_transforms) == 0: + print("No likelihood transforms provided. 
Using prior parameters as likelihood parameters") + seed = kwargs.get("seed", 0) @@ -75,7 +84,13 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) - prior = self.prior.log_prob(named_params) + transform_jacobian = 0.0 + for transform in self.sample_transforms: + named_params, jacobian = transform.inverse(named_params) + transform_jacobian += jacobian + prior = self.prior.log_prob(named_params) + transform_jacobian + for transform in self.likelihood_transforms: + named_params = transform.forward(named_params) return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index bfabc2ad..fec18a79 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -28,18 +28,13 @@ def propagate_name(self, x: list[str]) -> list[str]: return list(input_set - from_set | to_set) -class NtoNTransform(Transform): +class NtoMTransform(Transform): transform_func: Callable[[dict[str, Float]], dict[str, Float]] - @property - def n_dim(self) -> int: - return len(self.name_mapping[0]) - - def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ - Transform the input x to transformed coordinate y and return the log Jacobian determinant. - This only works if the transform is a N -> N transform. + Push forward the input x to transformed coordinate y. Parameters ---------- @@ -50,17 +45,9 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: ------- y : dict[str, Float] The transformed dictionary. - log_det : Float - The log Jacobian determinant. """ x_copy = x.copy() - transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) - output_params = self.transform_func(transform_params) - jacobian = jax.jacfwd(self.transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) - ) + output_params = self.transform_func(x_copy) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -69,11 +56,21 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: lambda key: x_copy.update({key: output_params[key]}), list(output_params.keys()), ) - return x_copy, jacobian + return x_copy - def forward(self, x: dict[str, Float]) -> dict[str, Float]: + +class NtoNTransform(NtoMTransform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + @property + def n_dim(self) -> int: + return len(self.name_mapping[0]) + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ - Push forward the input x to transformed coordinate y. + Transform the input x to transformed coordinate y and return the log Jacobian determinant. + This only works if the transform is a N -> N transform. Parameters ---------- @@ -84,9 +81,15 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: ------- y : dict[str, Float] The transformed dictionary. + log_det : Float + The log Jacobian determinant. 
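The posterior wiring introduced in PATCH 082, reduced to plain functions: walk the sample transforms backwards (folding their log-Jacobians into the prior term), then push the result forwards through the likelihood transforms. A sketch assuming one hypothetical transform of each kind:

    import jax.numpy as jnp

    def sample_inverse(p):  # sampling space -> prior space, with log-Jacobian
        q = 1 / (1 + jnp.exp(-p["q_logit"]))
        log_jac = jnp.log(q) + jnp.log(1 - q)  # |dq / d logit(q)| = q(1 - q)
        return {"M_c": p["M_c"], "q": q}, log_jac

    def likelihood_forward(p):  # prior space -> likelihood space
        eta = p["q"] / (1 + p["q"]) ** 2
        M = p["M_c"] / eta ** (3.0 / 5)
        return {"m_1": M / (1 + p["q"]), "m_2": M * p["q"] / (1 + p["q"])}

    def posterior(p, log_prior, log_like):
        p, log_jac = sample_inverse(p)
        lp = log_prior(p) + log_jac  # prior picks up the Jacobian term
        return log_like(likelihood_forward(p)) + lp

    print(posterior({"M_c": 30.0, "q_logit": 0.5}, lambda p: 0.0, lambda p: 0.0))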
""" x_copy = x.copy() - output_params = self.transform_func(x_copy) + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -95,16 +98,13 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: lambda key: x_copy.update({key: output_params[key]}), list(output_params.keys()), ) - return x_copy + return x_copy, jacobian class BijectiveTransform(NtoNTransform): inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] - def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - return self.transform(x) - def inverse(self, y: dict[str, Float]) -> dict[str, Float]: """ Inverse transform the input y to original coordinate x. @@ -124,9 +124,7 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) - ) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], @@ -166,31 +164,6 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: return y_copy -class NtoMTransform(Transform): - - transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " m_dim"]] - - def __call__(self, x: dict[str, Float]) -> dict[str, Float]: - return self.forward(x) - - @abstractmethod - def forward(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Push forward the input x to transformed coordinate y. - - Parameters - ---------- - x : dict[str, Float] - The input dictionary. - - Returns - ------- - y : dict[str, Float] - The transformed dictionary. 
- """ - raise NotImplementedError - - class ScaleTransform(BijectiveTransform): scale: Float @@ -230,6 +203,7 @@ def __init__( for i in range(len(name_mapping[1])) } + class LogitTransform(BijectiveTransform): """ Logit transform following @@ -251,10 +225,13 @@ def __init__( for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.log(x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]])) + name_mapping[0][i]: jnp.log( + x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]]) + ) for i in range(len(name_mapping[1])) } + class ArcSineTransform(BijectiveTransform): """ ArcSine transformation @@ -280,6 +257,7 @@ def __init__( for i in range(len(name_mapping[1])) } + # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 1aaf5ea91a6774ecc89314af3b953ac34c75539f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 13:59:44 -0400 Subject: [PATCH 083/104] Remove transformation --- src/jimgw/jim.py | 4 ++-- src/jimgw/transforms.py | 37 ++++++++----------------------------- 2 files changed, 10 insertions(+), 31 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 81d56836..03bcbc39 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -7,7 +7,7 @@ from jaxtyping import Array, Float, PRNGKeyArray from jimgw.base import LikelihoodBase -from jimgw.prior import Prior, trace_prior_parent +from jimgw.prior import Prior class Jim(object): @@ -75,7 +75,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) - prior = self.prior.log_prob(named_params) + prior = self.prior.log_prob(named_params) return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4a84d82f..7d1b5f3a 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -59,9 +59,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) - ) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -125,9 +123,7 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log( - jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)) - ) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], @@ -231,6 +227,7 @@ def __init__( for i in range(len(name_mapping[1])) } + class LogitTransform(BijectiveTransform): """ Logit transformation @@ -252,10 +249,13 @@ def __init__( for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.log(x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]])) + name_mapping[0][i]: jnp.log( + x[name_mapping[1][i]] / (1 - x[name_mapping[1][i]]) + ) for i in range(len(name_mapping[1])) } + class ArcSineTransform(BijectiveTransform): """ ArcSine transformation @@ -281,6 +281,7 @@ def 
__init__( for i in range(len(name_mapping[1])) } + class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): """ Transform component masses to chirp mass and mass ratio. @@ -301,28 +302,6 @@ def __init__( self.inverse_transform_func = lambda x: Mc_q_to_m1_m2(x[0], x[1]) -class RectangleToTriangleTransform(BijectiveTransform): - """ - Transform a rectangle grid with bounds [0, 1] x [0, 1] to a triangle grid with vertices (0, 0), (1, 0), (0, 1), while preserving the area density. - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - self.transform_func = lambda x: [ - 1 - jnp.sqrt(1 - x[0]), - x[1] * jnp.sqrt(1 - x[0]), - ] - self.inverse_transform_func = lambda x: [1 - (1 - x[0]) ** 2, x[1] / (1 - x[0])] - - # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 5bfdda13161e954861774881701f49855ff7a267 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 14:01:02 -0400 Subject: [PATCH 084/104] Remove prior --- src/jimgw/prior.py | 47 ---------------------------------------------- 1 file changed, 47 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 7880648d..031a4133 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -12,7 +12,6 @@ ScaleTransform, OffsetTransform, ArcSineTransform, - RectangleToTriangleTransform, ) @@ -397,52 +396,6 @@ def __init__( ) -@jaxtyped(typechecker=typechecker) -class UniformComponentMassPrior(SequentialTransformPrior): - """ - A prior in the range [xmin, xmax) for component masses which assumes the - component masses to be uniformly distributed. - """ - - xmin: float - xmax: float - - def __repr__(self): - return f"UniformComponentMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" - - def __init__(self, xmin: float, xmax: float, parameter_names: list[str]): - self.parameter_names = parameter_names - assert self.n_dim == 2, "UniformComponentMassPrior needs to be 2D distributions" - self.xmax = xmax - self.xmin = xmin - super().__init__( - CombinePrior( - [ - UniformPrior(xmin, xmax, ["x_1"]), - UniformPrior(xmin, xmax, ["x_2"]), - ] - ), - [ - ScaleTransform( - ( - ["x_1", "x_2"], - [f"x_1/({xmax-xmin})", f"x_2/({xmax-xmin})"], - ), - 1 / (xmax - xmin), - ), - RectangleToTriangleTransform( - ( - [ - f"{self.parameter_names[0]}/({xmax-xmin})", - f"{self.parameter_names[1]}/({xmax-xmin})", - ], - [self.parameter_names[0], self.parameter_names[1]], - ) - ), - ], - ) - - @jaxtyped(typechecker=typechecker) class UniformComponentChirpMassPrior(PowerLawPrior): """ From 77a6ad1785bb46a5d5f9163d9741e79d83ebacb9 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 14:03:04 -0400 Subject: [PATCH 085/104] Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-class' into transform --- src/jimgw/jim.py | 32 ++++++++++--- src/jimgw/transforms.py | 101 +++++++++++----------------------------- 2 files changed, 53 insertions(+), 80 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 03bcbc39..3186c2c4 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -8,6 +8,7 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior +from jimgw.transforms import BijectiveTransform, NtoMTransform class Jim(object): @@ -19,22 +20,33 @@ class Jim(object): prior: Prior # Name of parameters to sample from - parameter_names: list[str] + 
sample_transforms: list[BijectiveTransform] + likelihood_transforms: list[NtoMTransform] sampler: Sampler def __init__( self, likelihood: LikelihoodBase, prior: Prior, - parameter_names: list[str] | None = None, + sample_transforms: list[BijectiveTransform] = [], + likelihood_transforms: list[NtoMTransform] = [], **kwargs, ): self.likelihood = likelihood self.prior = prior - if parameter_names is None: - print("No parameter names provided. Using prior names.") - parameter_names = prior.parameter_names - self.parameter_names = parameter_names + + self.sample_transforms = sample_transforms + self.likelihood_transforms = likelihood_transforms + + if len(sample_transforms) == 0: + print( + "No sample transforms provided. Using prior parameters as sampling parameters" + ) + + if len(likelihood_transforms) == 0: + print( + "No likelihood transforms provided. Using prior parameters as likelihood parameters" + ) seed = kwargs.get("seed", 0) @@ -75,7 +87,13 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = self.add_name(params) - prior = self.prior.log_prob(named_params) + transform_jacobian = 0.0 + for transform in self.sample_transforms: + named_params, jacobian = transform.inverse(named_params) + transform_jacobian += jacobian + prior = self.prior.log_prob(named_params) + transform_jacobian + for transform in self.likelihood_transforms: + named_params = transform.forward(named_params) return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 7d1b5f3a..8de02f40 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,11 +1,9 @@ -from abc import ABC, abstractmethod +from abc import ABC from typing import Callable import jax import jax.numpy as jnp -from jaxtyping import Float, Array - -from jimgw.single_event.utils import m1_m2_to_Mc_q, Mc_q_to_m1_m2 +from jaxtyping import Float class Transform(ABC): @@ -29,18 +27,13 @@ def propagate_name(self, x: list[str]) -> list[str]: return list(input_set - from_set | to_set) -class NtoNTransform(Transform): +class NtoMTransform(Transform): transform_func: Callable[[dict[str, Float]], dict[str, Float]] - @property - def n_dim(self) -> int: - return len(self.name_mapping[0]) - - def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + def forward(self, x: dict[str, Float]) -> dict[str, Float]: """ - Transform the input x to transformed coordinate y and return the log Jacobian determinant. - This only works if the transform is a N -> N transform. + Push forward the input x to transformed coordinate y. Parameters ---------- @@ -51,15 +44,9 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: ------- y : dict[str, Float] The transformed dictionary. - log_det : Float - The log Jacobian determinant. 
""" x_copy = x.copy() - transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) - output_params = self.transform_func(transform_params) - jacobian = jax.jacfwd(self.transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + output_params = self.transform_func(x_copy) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -68,11 +55,21 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: lambda key: x_copy.update({key: output_params[key]}), list(output_params.keys()), ) - return x_copy, jacobian + return x_copy - def forward(self, x: dict[str, Float]) -> dict[str, Float]: + +class NtoNTransform(NtoMTransform): + + transform_func: Callable[[dict[str, Float]], dict[str, Float]] + + @property + def n_dim(self) -> int: + return len(self.name_mapping[0]) + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: """ - Push forward the input x to transformed coordinate y. + Transform the input x to transformed coordinate y and return the log Jacobian determinant. + This only works if the transform is a N -> N transform. Parameters ---------- @@ -83,9 +80,15 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: ------- y : dict[str, Float] The transformed dictionary. + log_det : Float + The log Jacobian determinant. """ x_copy = x.copy() - output_params = self.transform_func(x_copy) + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -94,16 +97,13 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: lambda key: x_copy.update({key: output_params[key]}), list(output_params.keys()), ) - return x_copy + return x_copy, jacobian class BijectiveTransform(NtoNTransform): inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]] - def __call__(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: - return self.transform(x) - def inverse(self, y: dict[str, Float]) -> dict[str, Float]: """ Inverse transform the input y to original coordinate x. @@ -163,31 +163,6 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: return y_copy -class NtoMTransform(Transform): - - transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " m_dim"]] - - def __call__(self, x: dict[str, Float]) -> dict[str, Float]: - return self.forward(x) - - @abstractmethod - def forward(self, x: dict[str, Float]) -> dict[str, Float]: - """ - Push forward the input x to transformed coordinate y. - - Parameters - ---------- - x : dict[str, Float] - The input dictionary. - - Returns - ------- - y : dict[str, Float] - The transformed dictionary. - """ - raise NotImplementedError - - class ScaleTransform(BijectiveTransform): scale: Float @@ -230,7 +205,7 @@ def __init__( class LogitTransform(BijectiveTransform): """ - Logit transformation + Logit transform following Parameters ---------- @@ -282,26 +257,6 @@ def __init__( } -class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): - """ - Transform component masses to chirp mass and mass ratio. 
- - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - assert name_mapping == (["m_1", "m_2"], ["M_c", "q"]) - super().__init__(name_mapping) - self.transform_func = lambda x: m1_m2_to_Mc_q(x[0], x[1]) - self.inverse_transform_func = lambda x: Mc_q_to_m1_m2(x[0], x[1]) - - # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 2401df8a8acf989305c158051b887cc5620b3218 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 31 Jul 2024 16:02:55 -0400 Subject: [PATCH 086/104] Add mass transform --- src/jimgw/single_event/utils.py | 40 +++++++++++++++++++ src/jimgw/transforms.py | 70 +++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 3d07bb57..aaf4b947 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -163,6 +163,46 @@ def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: return M, eta +def Mc_q_to_eta(M_c: Float, q: Float) -> Float: + """ + Transforming the chirp mass M_c and mass ratio q to the symmetric mass ratio eta. + + Parameters + ---------- + M_c : Float + Chirp mass. + q : Float + Mass ratio. + + Returns + ------- + eta : Float + Symmetric mass ratio. + """ + eta = q / (1 + q) ** 2 + return eta + + +def eta_to_q(eta: Float) -> Float: + """ + Transforming the symmetric mass ratio eta to the mass ratio q. + + Copied and modified from bilby/gw/conversion.py + + Parameters + ---------- + eta : Float + Symmetric mass ratio. + + Returns + ------- + q : Float + Mass ratio. + """ + temp = 1 / eta / 2 - 1 + return temp - (temp**2 - 1) ** 0.5 + + def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: """ Transforming the right ascension ra and declination dec to the polar angle diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 8de02f40..2e80e3c5 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -5,6 +5,8 @@ import jax.numpy as jnp from jaxtyping import Float +from jimgw.single_event.utils import Mc_q_to_m1_m2, m1_m2_to_Mc_q, Mc_q_to_eta, eta_to_q + class Transform(ABC): """ @@ -257,6 +259,74 @@ def __init__( } +class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to component masses + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + Mc = x[name_mapping[0][0]] + q = x[name_mapping[0][1]] + m1, m2 = Mc_q_to_m1_m2(Mc, q) + return {name_mapping[1][0]: m1, name_mapping[1][1]: m2} + + self.transform_func = named_transform + + def named_inverse_transform(x): + m1 = x[name_mapping[1][0]] + m2 = x[name_mapping[1][1]] + Mc, q = m1_m2_to_Mc_q(m1, m2) + return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + + self.inverse_transform_func = named_inverse_transform + + +class ChirpMassMassRatioToChirpMassSymmetricMassRatioTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to chirp mass and symmetric mass ratio + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
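A quick numerical check of the two helpers added above (a standalone copy of the same formulas; note that Mc_q_to_eta never uses its M_c argument, and a later patch in this series renames it to q_to_eta accordingly):

import jax.numpy as jnp

def q_to_eta(q):
    # same formula as Mc_q_to_eta above; the chirp-mass argument plays no role
    return q / (1 + q) ** 2

def eta_to_q(eta):
    # inverse on the physical branch q <= 1 (i.e. eta <= 0.25)
    temp = 1 / eta / 2 - 1
    return temp - (temp**2 - 1) ** 0.5

q = jnp.array(0.5)
eta = q_to_eta(q)
print(eta)            # 2/9, about 0.2222
print(eta_to_q(eta))  # recovers 0.5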
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + Mc = x[name_mapping[0][0]] + q = x[name_mapping[0][1]] + eta = Mc_q_to_eta(Mc, q) + return {name_mapping[1][0]: Mc, name_mapping[1][1]: eta} + + self.transform_func = named_transform + + def named_inverse_transform(x): + Mc = x[name_mapping[1][0]] + eta = x[name_mapping[1][1]] + q = eta_to_q(Mc, eta) + return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + + self.inverse_transform_func = named_inverse_transform + + # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From ff35a82f0c4db96f64507ef5bcf4b06a58738000 Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 31 Jul 2024 16:24:48 -0400 Subject: [PATCH 087/104] Add Bound transforming transform --- src/jimgw/transforms.py | 112 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index fec18a79..57e787a5 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -4,7 +4,8 @@ import jax import jax.numpy as jnp from chex import assert_rank -from jaxtyping import Float, Array +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped class Transform(ABC): @@ -258,6 +259,115 @@ def __init__( } +@jaxtyped(typechecker=typechecker) +class BoundToBound(BijectiveTransform): + + """ + Bound to bound transformation + """ + + original_lower_bound: Float[Array, " n_dim"] + original_upper_bound: Float[Array, " n_dim"] + target_lower_bound: Float[Array, " n_dim"] + target_upper_bound: Float[Array, " n_dim"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float[Array, " n_dim"], + original_upper_bound: Float[Array, " n_dim"], + target_lower_bound: Float[Array, " n_dim"], + target_upper_bound: Float[Array, " n_dim"], + ): + super().__init__(name_mapping) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + self.target_lower_bound = target_lower_bound + self.target_upper_bound = target_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: (x[name_mapping[0][i]] - self.original_lower_bound[i]) + * (self.target_upper_bound[i] - self.target_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) + + self.target_lower_bound[i] + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: (x[name_mapping[1][i]] - self.target_lower_bound[i]) + * (self.original_upper_bound[i] - self.original_lower_bound[i]) + / (self.target_upper_bound[i] - self.target_lower_bound[i]) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + +class BoundToUnbound(BijectiveTransform): + """ + Bound to unbound transformation + """ + + original_lower_bound: Float + original_upper_bound: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, + original_upper_bound: Float, + ): + + def logit(x): + return jnp.log(x / (1 - x)) + + super().__init__(name_mapping) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: logit( + (x[name_mapping[0][i]] - self.original_lower_bound) + / (self.original_upper_bound - self.original_lower_bound) + ) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = 
lambda x: {
+            name_mapping[0][i]: (
+                self.original_upper_bound - self.original_lower_bound
+            )
+            / (
+                1
+                + jnp.exp(-x[name_mapping[1][i]])
+            )
+            + self.original_lower_bound[i]
+            for i in range(len(name_mapping[1]))
+        }
+
+class SingleSidedUnboundTransform(BijectiveTransform):
+    """
+    Unbound upper limit transformation
+
+    Parameters
+    ----------
+    name_mapping : tuple[list[str], list[str]]
+        The name mapping between the input and output dictionary.
+
+    """
+
+    def __init__(
+        self,
+        name_mapping: tuple[list[str], list[str]],
+    ):
+        super().__init__(name_mapping)
+        self.transform_func = lambda x: {
+            name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]])
+            for i in range(len(name_mapping[0]))
+        }
+        self.inverse_transform_func = lambda x: {
+            name_mapping[0][i]: jnp.log(x[name_mapping[1][i]])
+            for i in range(len(name_mapping[1]))
+        }
+
+
 # class PowerLawTransform(UnivariateTransform):
 #     """
 #     PowerLaw transformation

From 47af9cfb2f70ac18e3547cd2f1be537f2e83aa57 Mon Sep 17 00:00:00 2001
From: Kaze Wong
Date: Wed, 31 Jul 2024 23:08:44 -0400
Subject: [PATCH 088/104] No NaNs now, but the code can be consolidated

---
 src/jimgw/jim.py                  | 13 +++++++++++--
 src/jimgw/transforms.py           |  5 +++--
 test/integration/test_GW150914.py | 17 +++++++++++++++++
 3 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py
index 08878293..39867e9f 100644
--- a/src/jimgw/jim.py
+++ b/src/jimgw/jim.py
@@ -22,6 +22,7 @@ class Jim(object):
     # Name of parameters to sample from
     sample_transforms: list[BijectiveTransform]
     likelihood_transforms: list[NtoMTransform]
+    parameter_names: list[str]
     sampler: Sampler

     def __init__(
@@ -37,9 +38,14 @@ def __init__(

         self.sample_transforms = sample_transforms
         self.likelihood_transforms = likelihood_transforms
+        self.parameter_names = prior.parameter_names

         if len(sample_transforms) == 0:
             print("No sample transforms provided. Using prior parameters as sampling parameters")
+        else:
+            print("Using sample transforms")
+            for transform in sample_transforms:
+                self.parameter_names = transform.propagate_name(self.parameter_names)

         if len(likelihood_transforms) == 0:
             print("No likelihood transforms provided. 
Using prior parameters as likelihood parameters") @@ -91,12 +97,15 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): prior = self.prior.log_prob(named_params) + transform_jacobian for transform in self.likelihood_transforms: named_params = transform.forward(named_params) - return self.likelihood.evaluate(named_params, data) + prior + named_params = jax.tree.map(lambda x:x[0], named_params) # This [0] should be consolidate + return self.likelihood.evaluate(named_params, data) + prior[0] # This prior [0] should be consolidate def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T + for transform in self.sample_transforms: + initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) + initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[0] # This [0] should be consolidate self.Sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 57e787a5..b832ccdc 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -300,6 +300,7 @@ def __init__( for i in range(len(name_mapping[1])) } +@jaxtyped(typechecker=typechecker) class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -319,8 +320,8 @@ def logit(x): return jnp.log(x / (1 - x)) super().__init__(name_mapping) - self.original_lower_bound = original_lower_bound - self.original_upper_bound = original_upper_bound + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.original_upper_bound = jnp.atleast_1d(original_upper_bound) self.transform_func = lambda x: { name_mapping[1][i]: logit( diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index deb3fb98..d82e8d7a 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,6 +8,7 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -64,6 +65,21 @@ dec_prior, ] ) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, original_upper_bound=jnp.pi/2), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], 
["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi) +] + likelihood = TransientLikelihoodFD( [H1, L1], waveform=RippleIMRPhenomD(), @@ -88,6 +104,7 @@ jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, n_loop_training=n_loop_training, n_loop_production=1, n_local_steps=5, From b5e06a6972264597a03eb7f1afd2ab42d1e1eabd Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 10:19:40 -0400 Subject: [PATCH 089/104] Solve conflict --- src/jimgw/transforms.py | 107 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 2e80e3c5..b405f066 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -3,7 +3,8 @@ import jax import jax.numpy as jnp -from jaxtyping import Float +from beartype import beartype as typechecker +from jaxtyping import Float, Array, jaxtyped from jimgw.single_event.utils import Mc_q_to_m1_m2, m1_m2_to_Mc_q, Mc_q_to_eta, eta_to_q @@ -259,6 +260,110 @@ def __init__( } +@jaxtyped(typechecker=typechecker) +class BoundToBound(BijectiveTransform): + """ + Bound to bound transformation + """ + + original_lower_bound: Float[Array, " n_dim"] + original_upper_bound: Float[Array, " n_dim"] + target_lower_bound: Float[Array, " n_dim"] + target_upper_bound: Float[Array, " n_dim"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float[Array, " n_dim"], + original_upper_bound: Float[Array, " n_dim"], + target_lower_bound: Float[Array, " n_dim"], + target_upper_bound: Float[Array, " n_dim"], + ): + super().__init__(name_mapping) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + self.target_lower_bound = target_lower_bound + self.target_upper_bound = target_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: (x[name_mapping[0][i]] - self.original_lower_bound[i]) + * (self.target_upper_bound[i] - self.target_lower_bound[i]) + / (self.original_upper_bound[i] - self.original_lower_bound[i]) + + self.target_lower_bound[i] + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: (x[name_mapping[1][i]] - self.target_lower_bound[i]) + * (self.original_upper_bound[i] - self.original_lower_bound[i]) + / (self.target_upper_bound[i] - self.target_lower_bound[i]) + + self.original_lower_bound[i] + for i in range(len(name_mapping[1])) + } + + +class BoundToUnbound(BijectiveTransform): + """ + Bound to unbound transformation + """ + + original_lower_bound: Float + original_upper_bound: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, + original_upper_bound: Float, + ): + + def logit(x): + return jnp.log(x / (1 - x)) + + super().__init__(name_mapping) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound + + self.transform_func = lambda x: { + name_mapping[1][i]: logit( + (x[name_mapping[0][i]] - self.original_lower_bound) + / (self.original_upper_bound - self.original_lower_bound) + ) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound) + / (1 + jnp.exp(-x[name_mapping[1][i]])) + + self.original_lower_bound[i] + for i in 
range(len(name_mapping[1])) + } + + +class SingleSidedUnboundTransform(BijectiveTransform): + """ + Unbound upper limit transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.log(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): """ Transform chirp mass and mass ratio to component masses From 8e5b326c295f94f2390ab40427dd63455f2839b0 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 11:06:30 -0400 Subject: [PATCH 090/104] Add sky position transform --- src/jimgw/single_event/utils.py | 143 +++++++++++++++++++++++++------- src/jimgw/transforms.py | 47 ++++++++++- 2 files changed, 155 insertions(+), 35 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index aaf4b947..2373c5be 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -203,32 +203,6 @@ def eta_to_q(eta: Float) -> Float: return temp - (temp**2 - 1) ** 0.5 -def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: - """ - Transforming the right ascension ra and declination dec to the polar angle - theta and azimuthal angle phi. - - Parameters - ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. - - Returns - ------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. - """ - phi = ra - gmst - theta = jnp.pi / 2 - dec - return theta, phi - - def euler_rotation(delta_x: Float[Array, " 3"]): """ Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x @@ -239,11 +213,10 @@ def euler_rotation(delta_x: Float[Array, " 3"]): Copied and modified from bilby-cython/geometry.pyx """ - norm = jnp.power( - delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5 - ) + norm = jnp.linalg.vector_norm(delta_x) + cos_beta = delta_x[2] / norm - sin_beta = jnp.power(1 - cos_beta**2, 0.5) + sin_beta = jnp.sqrt(1 - cos_beta**2) alpha = jnp.atan2(-delta_x[1] * cos_beta, delta_x[0]) gamma = jnp.atan2(delta_x[1], delta_x[0]) @@ -318,7 +291,7 @@ def zenith_azimuth_to_theta_phi( + rotation[0][2] * cos_zenith, ) + 2 * jnp.pi, - (2 * jnp.pi), + 2 * jnp.pi, ) return theta, phi @@ -345,6 +318,7 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F """ ra = phi + gmst dec = jnp.pi / 2 - theta + ra = ra % (2 * jnp.pi) return ra, dec @@ -376,10 +350,115 @@ def zenith_azimuth_to_ra_dec( """ theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) - ra = ra % (2 * jnp.pi) return ra, dec +def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the right ascension ra and declination dec to the polar angle + theta and azimuthal angle phi. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. 
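A property check for euler_rotation as modified above; this assumes a checkout with the patch applied and leans only on the documented contract that the matrix maps (0, 0, 1) to the direction of delta_x:

import jax.numpy as jnp
from jimgw.single_event.utils import euler_rotation

delta_x = jnp.array([3.0, -2.0, 1.0])
R = jnp.asarray(euler_rotation(delta_x))

mapped = R @ jnp.array([0.0, 0.0, 1.0])           # image of the z-axis
expected = delta_x / jnp.linalg.vector_norm(delta_x)
print(jnp.allclose(mapped, expected, atol=1e-6))  # True if the contract holds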
+ """ + phi = ra - gmst + theta = jnp.pi / 2 - dec + phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) + return theta, phi + + +def theta_phi_to_zenith_azimuth( + theta: Float, phi: Float, delta_x: Float[Array, " 3"] +) -> tuple[Float, Float]: + """ + Transforming the polar angle and azimuthal angle to the zenith angle and azimuthal angle. + + Parameters + ---------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + delta_x : Float + The vector pointing from the first detector to the second detector. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + """ + sin_theta = jnp.sin(theta) + cos_theta = jnp.cos(theta) + sin_phi = jnp.sin(phi) + cos_phi = jnp.cos(phi) + + rotation = euler_rotation(delta_x) + rotation = jnp.linalg.inv(rotation) + + zenith = jnp.acos( + rotation[2][0] * sin_theta * cos_phi + + rotation[2][1] * sin_theta * sin_phi + + rotation[2][2] * cos_theta + ) + azimuth = jnp.fmod( + jnp.atan2( + rotation[1][0] * sin_theta * cos_phi + + rotation[1][1] * sin_theta * sin_phi + + rotation[1][2] * cos_theta, + rotation[0][0] * sin_theta * cos_phi + + rotation[0][1] * sin_theta * sin_phi + + rotation[0][2] * cos_theta, + ) + + 2 * jnp.pi, + 2 * jnp.pi, + ) + return zenith, azimuth + + +def ra_dec_to_zenith_azimuth( + ra: Float, dec: Float, gmst: Float, delta_x: Float[Array, " 3"] +) -> tuple[Float, Float]: + """ + Transforming the right ascension and declination to the zenith angle and azimuthal angle. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + delta_x : Float + The vector pointing from the first detector to the second detector. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + """ + theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) + zenith, azimuth = theta_phi_to_zenith_azimuth(theta, phi, delta_x) + return zenith, azimuth + + def log_i0(x: Float[Array, " n"]) -> Float[Array, " n"]: """ A numerically stable method to evaluate log of diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 5d0d1643..d644c8e1 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -6,7 +6,14 @@ from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped -from jimgw.single_event.utils import Mc_q_to_m1_m2, m1_m2_to_Mc_q, Mc_q_to_eta, eta_to_q +from jimgw.single_event.utils import ( + Mc_q_to_m1_m2, + m1_m2_to_Mc_q, + Mc_q_to_eta, + eta_to_q, + ra_dec_to_zenith_azimuth, + zenith_azimuth_to_ra_dec, +) class Transform(ABC): @@ -300,7 +307,7 @@ def __init__( for i in range(len(name_mapping[1])) } - + class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -315,7 +322,7 @@ def __init__( original_lower_bound: Float, original_upper_bound: Float, ): - + def logit(x): return jnp.log(x / (1 - x)) @@ -432,6 +439,40 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): + """ + Transform sky frame to detector frame sky position + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
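For context, the gmst expected by these sky functions is the Greenwich sidereal time of the event in radians. A sketch of one way to obtain it with astropy (an assumption for illustration; whether jimgw derives it with the "apparent" or "mean" convention is not shown in this patch), using the GW150914 trigger time from this series' integration test:

from astropy.time import Time

trigger_time = 1126259462.4  # GPS trigger time of GW150914
gmst = Time(trigger_time, format="gps").sidereal_time("apparent", "greenwich").rad
print(gmst)  # radians, suitable for the gmst argument above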
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + ra = x[name_mapping[0][0]] + dec = x[name_mapping[0][1]] + zenith, azimuth = ra_dec_to_zenith_azimuth(ra, dec) + return {name_mapping[1][0]: zenith, name_mapping[1][1]: azimuth} + + self.transform_func = named_transform + + def named_inverse_transform(x): + zenith = x[name_mapping[1][0]] + azimuth = x[name_mapping[1][1]] + ra, dec = zenith_azimuth_to_ra_dec(zenith, azimuth) + return {name_mapping[0][0]: ra, name_mapping[0][1]: dec} + + self.inverse_transform_func = named_inverse_transform + + # class PowerLawTransform(UnivariateTransform): # """ # PowerLaw transformation From 10d51b2b05723b4e0eb3e9b0e3bf2342b841b535 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 11:21:06 -0400 Subject: [PATCH 091/104] Modify sky position transform --- src/jimgw/single_event/utils.py | 64 ++++----------------------------- src/jimgw/transforms.py | 19 ++++++++-- 2 files changed, 23 insertions(+), 60 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 2373c5be..705857fe 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -245,8 +245,8 @@ def euler_rotation(delta_x: Float[Array, " 3"]): return rotation -def zenith_azimuth_to_theta_phi( - zenith: Float, azimuth: Float, delta_x: Float[Array, " 3"] +def angle_rotation( + zenith: Float, azimuth: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame. @@ -274,8 +274,6 @@ def zenith_azimuth_to_theta_phi( sin_zenith = jnp.sin(zenith) cos_zenith = jnp.cos(zenith) - rotation = euler_rotation(delta_x) - theta = jnp.acos( rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth @@ -323,7 +321,7 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F def zenith_azimuth_to_ra_dec( - zenith: Float, azimuth: Float, gmst: Float, delta_x: Float[Array, " 3"] + zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. @@ -348,7 +346,7 @@ def zenith_azimuth_to_ra_dec( dec : Float Declination. """ - theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + theta, phi = angle_rotation(zenith, azimuth, rotation) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) return ra, dec @@ -380,58 +378,8 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa return theta, phi -def theta_phi_to_zenith_azimuth( - theta: Float, phi: Float, delta_x: Float[Array, " 3"] -) -> tuple[Float, Float]: - """ - Transforming the polar angle and azimuthal angle to the zenith angle and azimuthal angle. - - Parameters - ---------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. - delta_x : Float - The vector pointing from the first detector to the second detector. - - Returns - ------- - zenith : Float - Zenith angle. - azimuth : Float - Azimuthal angle. 
- """ - sin_theta = jnp.sin(theta) - cos_theta = jnp.cos(theta) - sin_phi = jnp.sin(phi) - cos_phi = jnp.cos(phi) - - rotation = euler_rotation(delta_x) - rotation = jnp.linalg.inv(rotation) - - zenith = jnp.acos( - rotation[2][0] * sin_theta * cos_phi - + rotation[2][1] * sin_theta * sin_phi - + rotation[2][2] * cos_theta - ) - azimuth = jnp.fmod( - jnp.atan2( - rotation[1][0] * sin_theta * cos_phi - + rotation[1][1] * sin_theta * sin_phi - + rotation[1][2] * cos_theta, - rotation[0][0] * sin_theta * cos_phi - + rotation[0][1] * sin_theta * sin_phi - + rotation[0][2] * cos_theta, - ) - + 2 * jnp.pi, - 2 * jnp.pi, - ) - return zenith, azimuth - - def ra_dec_to_zenith_azimuth( - ra: Float, dec: Float, gmst: Float, delta_x: Float[Array, " 3"] + ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: """ Transforming the right ascension and declination to the zenith angle and azimuthal angle. @@ -455,7 +403,7 @@ def ra_dec_to_zenith_azimuth( Azimuthal angle. """ theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) - zenith, azimuth = theta_phi_to_zenith_azimuth(theta, phi, delta_x) + zenith, azimuth = angle_rotation(theta, phi, rotation) return zenith, azimuth diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index d644c8e1..5a889569 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -13,6 +13,7 @@ eta_to_q, ra_dec_to_zenith_azimuth, zenith_azimuth_to_ra_dec, + euler_rotation, ) @@ -450,16 +451,28 @@ class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): """ + gmst: Float + rotation: Float[Array, " 3 3"] + rotation_inv: Float[Array, " 3 3"] + def __init__( self, name_mapping: tuple[list[str], list[str]], + gmst: Float, + delta_x: Float, ): super().__init__(name_mapping) + self.gmst = gmst + self.rotation = euler_rotation(delta_x) + self.rotation_inv = jnp.linalg.inv(self.rotation) + def named_transform(x): ra = x[name_mapping[0][0]] dec = x[name_mapping[0][1]] - zenith, azimuth = ra_dec_to_zenith_azimuth(ra, dec) + zenith, azimuth = ra_dec_to_zenith_azimuth( + ra, dec, self.gmst, self.rotation + ) return {name_mapping[1][0]: zenith, name_mapping[1][1]: azimuth} self.transform_func = named_transform @@ -467,7 +480,9 @@ def named_transform(x): def named_inverse_transform(x): zenith = x[name_mapping[1][0]] azimuth = x[name_mapping[1][1]] - ra, dec = zenith_azimuth_to_ra_dec(zenith, azimuth) + ra, dec = zenith_azimuth_to_ra_dec( + zenith, azimuth, self.gmst, self.rotation_inv + ) return {name_mapping[0][0]: ra, name_mapping[0][1]: dec} self.inverse_transform_func = named_inverse_transform From 8368d00bd9a5d3a41e670036bdcfcbeec07dfef3 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 11:55:30 -0400 Subject: [PATCH 092/104] Change util func name --- src/jimgw/single_event/utils.py | 2 +- src/jimgw/transforms.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 705857fe..8721c176 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -163,7 +163,7 @@ def m1_m2_to_M_eta(m1: Float, m2: Float) -> tuple[Float, Float]: return M, eta -def Mc_q_to_eta(M_c: Float, q: Float) -> Float: +def q_to_eta(q: Float) -> Float: """ Transforming the chirp mass M_c and mass ratio q to the symmetric mass ratio eta. 
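The rename above is behavior-preserving because the old function never used its chirp-mass argument (its first docstring line, which still mentions M_c, reads as a leftover of the old signature). A minimal check under that reading:

import jax.numpy as jnp

def Mc_q_to_eta(M_c, q):  # old signature; M_c is unused
    return q / (1 + q) ** 2

def q_to_eta(q):          # new signature after the rename
    return q / (1 + q) ** 2

q = jnp.linspace(0.1, 1.0, 5)
print(jnp.allclose(Mc_q_to_eta(30.0, q), q_to_eta(q)))  # True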
diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py
index 5a889569..165432eb 100644
--- a/src/jimgw/transforms.py
+++ b/src/jimgw/transforms.py
@@ -9,7 +9,7 @@
 from jimgw.single_event.utils import (
     Mc_q_to_m1_m2,
     m1_m2_to_Mc_q,
-    Mc_q_to_eta,
+    q_to_eta,
     eta_to_q,
     ra_dec_to_zenith_azimuth,
     zenith_azimuth_to_ra_dec,
@@ -426,7 +426,7 @@ def __init__(
         def named_transform(x):
             Mc = x[name_mapping[0][0]]
             q = x[name_mapping[0][1]]
-            eta = Mc_q_to_eta(Mc, q)
+            eta = q_to_eta(q)
             return {name_mapping[1][0]: Mc, name_mapping[1][1]: eta}

         self.transform_func = named_transform

From 1cb0d11c2f58c8413cc34f823ca7943c0f820f59 Mon Sep 17 00:00:00 2001
From: Thomas Ng
Date: Thu, 1 Aug 2024 11:59:38 -0400
Subject: [PATCH 093/104] Revert "Merge branch
 '98-moving-naming-tracking-into-jim-class-from-prior-class' into transform"

This reverts commit b4f60527bf7ec4e002698e32458b7da659a88da5, reversing
changes made to 8368d00bd9a5d3a41e670036bdcfcbeec07dfef3.
---
 src/jimgw/jim.py                  |  24 ++---
 src/jimgw/transforms.py           | 147 +++++++++++++++++++++++++++---
 test/integration/test_GW150914.py |  17 ----
 3 files changed, 142 insertions(+), 46 deletions(-)

diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py
index 39867e9f..3186c2c4 100644
--- a/src/jimgw/jim.py
+++ b/src/jimgw/jim.py
@@ -7,7 +7,7 @@
 from jaxtyping import Array, Float, PRNGKeyArray

 from jimgw.base import LikelihoodBase
 from jimgw.prior import Prior
 from jimgw.transforms import BijectiveTransform, NtoMTransform


 class Jim(object):
@@ -22,7 +22,6 @@ class Jim(object):
     # Name of parameters to sample from
     sample_transforms: list[BijectiveTransform]
     likelihood_transforms: list[NtoMTransform]
-    parameter_names: list[str]
     sampler: Sampler

     def __init__(
@@ -38,18 +37,16 @@ def __init__(

         self.sample_transforms = sample_transforms
         self.likelihood_transforms = likelihood_transforms
-        self.parameter_names = prior.parameter_names

         if len(sample_transforms) == 0:
-            print("No sample transforms provided. Using prior parameters as sampling parameters")
-        else:
-            print("Using sample transforms")
-            for transform in sample_transforms:
-                self.parameter_names = transform.propagate_name(self.parameter_names)
+            print(
+                "No sample transforms provided. 
Using prior parameters as likelihood parameters" + ) seed = kwargs.get("seed", 0) @@ -97,15 +94,12 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): prior = self.prior.log_prob(named_params) + transform_jacobian for transform in self.likelihood_transforms: named_params = transform.forward(named_params) - named_params = jax.tree.map(lambda x:x[0], named_params) # This [0] should be consolidate - return self.likelihood.evaluate(named_params, data) + prior[0] # This prior [0] should be consolidate + return self.likelihood.evaluate(named_params, data) + prior def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) - for transform in self.sample_transforms: - initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[0] # This [0] should be consolidate + initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T self.Sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index b832ccdc..165432eb 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -1,12 +1,21 @@ -from abc import ABC, abstractmethod +from abc import ABC from typing import Callable import jax import jax.numpy as jnp -from chex import assert_rank from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped +from jimgw.single_event.utils import ( + Mc_q_to_m1_m2, + m1_m2_to_Mc_q, + q_to_eta, + eta_to_q, + ra_dec_to_zenith_azimuth, + zenith_azimuth_to_ra_dec, + euler_rotation, +) + class Transform(ABC): """ @@ -261,7 +270,6 @@ def __init__( @jaxtyped(typechecker=typechecker) class BoundToBound(BijectiveTransform): - """ Bound to bound transformation """ @@ -300,7 +308,7 @@ def __init__( for i in range(len(name_mapping[1])) } -@jaxtyped(typechecker=typechecker) + class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -315,13 +323,13 @@ def __init__( original_lower_bound: Float, original_upper_bound: Float, ): - + def logit(x): return jnp.log(x / (1 - x)) super().__init__(name_mapping) - self.original_lower_bound = jnp.atleast_1d(original_lower_bound) - self.original_upper_bound = jnp.atleast_1d(original_upper_bound) + self.original_lower_bound = original_lower_bound + self.original_upper_bound = original_upper_bound self.transform_func = lambda x: { name_mapping[1][i]: logit( @@ -331,17 +339,13 @@ def logit(x): for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: ( - self.original_upper_bound - self.original_lower_bound - ) - / ( - 1 - + jnp.exp(-x[name_mapping[1][i]]) - ) + name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound) + / (1 + jnp.exp(-x[name_mapping[1][i]])) + self.original_lower_bound[i] for i in range(len(name_mapping[1])) } + class SingleSidedUnboundTransform(BijectiveTransform): """ Unbound upper limit transformation @@ -368,6 +372,121 @@ def __init__( } +class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to component masses + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. 
+ + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + Mc = x[name_mapping[0][0]] + q = x[name_mapping[0][1]] + m1, m2 = Mc_q_to_m1_m2(Mc, q) + return {name_mapping[1][0]: m1, name_mapping[1][1]: m2} + + self.transform_func = named_transform + + def named_inverse_transform(x): + m1 = x[name_mapping[1][0]] + m2 = x[name_mapping[1][1]] + Mc, q = m1_m2_to_Mc_q(m1, m2) + return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + + self.inverse_transform_func = named_inverse_transform + + +class ChirpMassMassRatioToChirpMassSymmetricMassRatioTransform(BijectiveTransform): + """ + Transform chirp mass and mass ratio to chirp mass and symmetric mass ratio + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + + def named_transform(x): + Mc = x[name_mapping[0][0]] + q = x[name_mapping[0][1]] + eta = q_to_eta(q) + return {name_mapping[1][0]: Mc, name_mapping[1][1]: eta} + + self.transform_func = named_transform + + def named_inverse_transform(x): + Mc = x[name_mapping[1][0]] + eta = x[name_mapping[1][1]] + q = eta_to_q(Mc, eta) + return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} + + self.inverse_transform_func = named_inverse_transform + + +class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): + """ + Transform sky frame to detector frame sky position + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + rotation: Float[Array, " 3 3"] + rotation_inv: Float[Array, " 3 3"] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gmst: Float, + delta_x: Float, + ): + super().__init__(name_mapping) + + self.gmst = gmst + self.rotation = euler_rotation(delta_x) + self.rotation_inv = jnp.linalg.inv(self.rotation) + + def named_transform(x): + ra = x[name_mapping[0][0]] + dec = x[name_mapping[0][1]] + zenith, azimuth = ra_dec_to_zenith_azimuth( + ra, dec, self.gmst, self.rotation + ) + return {name_mapping[1][0]: zenith, name_mapping[1][1]: azimuth} + + self.transform_func = named_transform + + def named_inverse_transform(x): + zenith = x[name_mapping[1][0]] + azimuth = x[name_mapping[1][1]] + ra, dec = zenith_azimuth_to_ra_dec( + zenith, azimuth, self.gmst, self.rotation_inv + ) + return {name_mapping[0][0]: ra, name_mapping[0][1]: dec} + + self.inverse_transform_func = named_inverse_transform + # class PowerLawTransform(UnivariateTransform): # """ diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index d82e8d7a..deb3fb98 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,7 +8,6 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD -from jimgw.transforms import BoundToUnbound from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -65,21 +64,6 @@ dec_prior, ] ) - -sample_transforms = [ - BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25), - BoundToUnbound(name_mapping = 
[["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), - BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), - BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, original_upper_bound=jnp.pi/2), - BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi) -] - likelihood = TransientLikelihoodFD( [H1, L1], waveform=RippleIMRPhenomD(), @@ -104,7 +88,6 @@ jim = Jim( likelihood, prior, - sample_transforms=sample_transforms, n_loop_training=n_loop_training, n_loop_production=1, n_local_steps=5, From 78e93c502ed24b77f74f2cbb8b662ca54f3c4d56 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 12:03:00 -0400 Subject: [PATCH 094/104] Merge --- src/jimgw/jim.py | 19 +++++++++++++++++-- src/jimgw/transforms.py | 5 +++-- test/integration/test_GW150914.py | 17 +++++++++++++++++ 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 3186c2c4..2a9cb952 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -22,6 +22,7 @@ class Jim(object): # Name of parameters to sample from sample_transforms: list[BijectiveTransform] likelihood_transforms: list[NtoMTransform] + parameter_names: list[str] sampler: Sampler def __init__( @@ -37,11 +38,16 @@ def __init__( self.sample_transforms = sample_transforms self.likelihood_transforms = likelihood_transforms + self.parameter_names = prior.parameter_names if len(sample_transforms) == 0: print( "No sample transforms provided. 
Using prior parameters as sampling parameters" ) + else: + print("Using sample transforms") + for transform in sample_transforms: + self.parameter_names = transform.propagate_name(self.parameter_names) if len(likelihood_transforms) == 0: print( @@ -94,12 +100,21 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): prior = self.prior.log_prob(named_params) + transform_jacobian for transform in self.likelihood_transforms: named_params = transform.forward(named_params) - return self.likelihood.evaluate(named_params, data) + prior + named_params = jax.tree.map( + lambda x: x[0], named_params + ) # This [0] should be consolidate + return ( + self.likelihood.evaluate(named_params, data) + prior[0] + ) # This prior [0] should be consolidate def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T + for transform in self.sample_transforms: + initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) + initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[ + 0 + ] # This [0] should be consolidate self.Sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 165432eb..10c42c33 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -309,6 +309,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class BoundToUnbound(BijectiveTransform): """ Bound to unbound transformation @@ -328,8 +329,8 @@ def logit(x): return jnp.log(x / (1 - x)) super().__init__(name_mapping) - self.original_lower_bound = original_lower_bound - self.original_upper_bound = original_upper_bound + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.original_upper_bound = jnp.atleast_1d(original_upper_bound) self.transform_func = lambda x: { name_mapping[1][i]: logit( diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index deb3fb98..d82e8d7a 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,6 +8,7 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -64,6 +65,21 @@ dec_prior, ] ) + +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, 
original_upper_bound=jnp.pi/2), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi) +] + likelihood = TransientLikelihoodFD( [H1, L1], waveform=RippleIMRPhenomD(), @@ -88,6 +104,7 @@ jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, n_loop_training=n_loop_training, n_loop_production=1, n_local_steps=5, From 080bd8bb36b97f4a87ec2cbd7d84900e42d9b9e8 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 12:51:35 -0400 Subject: [PATCH 095/104] Modify integration test --- src/jimgw/transforms.py | 21 ++++----------------- test/integration/test_GW150914.py | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 10c42c33..3fc3cd1a 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -407,9 +407,9 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform -class ChirpMassMassRatioToChirpMassSymmetricMassRatioTransform(BijectiveTransform): +class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): """ - Transform chirp mass and mass ratio to chirp mass and symmetric mass ratio + Transform mass ratio to symmetric mass ratio Parameters ---------- @@ -424,21 +424,8 @@ def __init__( ): super().__init__(name_mapping) - def named_transform(x): - Mc = x[name_mapping[0][0]] - q = x[name_mapping[0][1]] - eta = q_to_eta(q) - return {name_mapping[1][0]: Mc, name_mapping[1][1]: eta} - - self.transform_func = named_transform - - def named_inverse_transform(x): - Mc = x[name_mapping[1][0]] - eta = x[name_mapping[1][1]] - q = eta_to_q(Mc, eta) - return {name_mapping[0][0]: Mc, name_mapping[0][1]: q} - - self.inverse_transform_func = named_inverse_transform + self.transform_func = lambda x: {name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]])} + self.inverse_transform_func = lambda x: {name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]])} class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index d82e8d7a..4cee35e9 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -8,7 +8,7 @@ from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD -from jimgw.transforms import BoundToUnbound +from jimgw.transforms import BoundToUnbound, MassRatioToSymmetricMassRatioTransform from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -34,10 +34,10 @@ L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -eta_prior = UniformPrior( +q_prior = UniformPrior( 0.125, - 0.25, - parameter_names=["eta"], # Need name transformation in likelihood to work + 1, + parameter_names=["q"], ) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) @@ -45,15 +45,15 @@ dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, 
parameter_names=["phase_c"]) -iota_prior = CosinePrior(parameter_names=["iota"]) +iota_prior = SinePrior(parameter_names=["iota"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) -dec_prior = SinePrior(parameter_names=["dec"]) +dec_prior = CosinePrior(parameter_names=["dec"]) prior = CombinePrior( [ Mc_prior, - eta_prior, + q_prior, s1z_prior, s2z_prior, dL_prior, @@ -68,7 +68,7 @@ sample_transforms = [ BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), @@ -80,6 +80,10 @@ BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi) ] +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + likelihood = TransientLikelihoodFD( [H1, L1], waveform=RippleIMRPhenomD(), @@ -105,6 +109,7 @@ likelihood, prior, sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=1, n_local_steps=5, From 4058d327418b3dbd9b742cb7af212a76eac8162f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 12:56:23 -0400 Subject: [PATCH 096/104] Reformat --- src/jimgw/transforms.py | 8 ++++++-- test/integration/test_GW150914.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 3fc3cd1a..0326b28c 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -424,8 +424,12 @@ def __init__( ): super().__init__(name_mapping) - self.transform_func = lambda x: {name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]])} - self.inverse_transform_func = lambda x: {name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]])} + self.transform_func = lambda x: { + name_mapping[1][0]: q_to_eta(x[name_mapping[0][0]]) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][0]: eta_to_q(x[name_mapping[1][0]]) + } class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 4cee35e9..753eb6b1 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -36,7 +36,7 @@ Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) q_prior = UniformPrior( 0.125, - 1, + 1., parameter_names=["q"], ) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) @@ -68,7 +68,7 @@ sample_transforms = [ BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), 
BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), From cdf771d1f3225045607601353a3f9e7ef3cfac54 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 13:03:35 -0400 Subject: [PATCH 097/104] Add typecheck --- src/jimgw/transforms.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 0326b28c..1bb17c18 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -174,6 +174,7 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: return y_copy +@jaxtyped(typechecker=typechecker) class ScaleTransform(BijectiveTransform): scale: Float @@ -194,6 +195,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class OffsetTransform(BijectiveTransform): offset: Float @@ -214,6 +216,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class LogitTransform(BijectiveTransform): """ Logit transform following @@ -242,6 +245,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class ArcSineTransform(BijectiveTransform): """ ArcSine transformation @@ -347,6 +351,7 @@ def logit(x): } +@jaxtyped(typechecker=typechecker) class SingleSidedUnboundTransform(BijectiveTransform): """ Unbound upper limit transformation @@ -373,6 +378,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform): """ Transform chirp mass and mass ratio to component masses @@ -407,6 +413,7 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): """ Transform mass ratio to symmetric mass ratio @@ -432,6 +439,7 @@ def __init__( } +@jaxtyped(typechecker=typechecker) class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): """ Transform sky frame to detector frame sky position From 02c5650a064fa749c2a63a9e2c885c230ca05523 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 13:21:35 -0400 Subject: [PATCH 098/104] minor typo --- src/jimgw/jim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 2a9cb952..19f74606 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -70,7 +70,7 @@ def __init__( self.prior.n_dim, num_layers, hidden_size, num_bins, subkey ) - self.Sampler = Sampler( + self.sampler = Sampler( self.prior.n_dim, rng_key, None, # type: ignore From ce7ac34d2e8094ab1cfaa31fafa657c1887fcf45 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 1 Aug 2024 13:38:40 -0400 Subject: [PATCH 099/104] Rename sampler --- src/jimgw/jim.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 19f74606..f21aff42 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -109,13 +109,13 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: - initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) + initial_guess_named = self.prior.sample(key, self.sampler.n_chains) for transform in self.sample_transforms: initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[ 0 ] # This [0] should be consolidate - 
+        self.sampler.sample(initial_guess, None)  # type: ignore
 
     def maximize_likelihood(
         self,
@@ -148,8 +148,8 @@ def print_summary(self, transform: bool = True):
 
         """
 
-        train_summary = self.Sampler.get_sampler_state(training=True)
-        production_summary = self.Sampler.get_sampler_state(training=False)
+        train_summary = self.sampler.get_sampler_state(training=True)
+        production_summary = self.sampler.get_sampler_state(training=False)
 
         training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T
         training_chain = self.prior.add_name(training_chain)
@@ -215,9 +215,9 @@ def get_samples(self, training: bool = False) -> dict:
         """
 
         if training:
-            chains = self.Sampler.get_sampler_state(training=True)["chains"]
+            chains = self.sampler.get_sampler_state(training=True)["chains"]
         else:
-            chains = self.Sampler.get_sampler_state(training=False)["chains"]
+            chains = self.sampler.get_sampler_state(training=False)["chains"]
 
         chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1)))
         return chains

From 7d44aa4ba9763a31a8310b40b3eac29d47e47893 Mon Sep 17 00:00:00 2001
From: Thomas Ng
Date: Thu, 1 Aug 2024 13:40:53 -0400
Subject: [PATCH 100/104] Fix test

---
 test/integration/test_GW150914.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py
index 753eb6b1..6193032a 100644
--- a/test/integration/test_GW150914.py
+++ b/test/integration/test_GW150914.py
@@ -74,10 +74,10 @@
     BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0),
     BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05),
     BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
-    BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, original_upper_bound=jnp.pi/2),
+    BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi),
     BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
     BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
-    BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi)
+    BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2)
 ]
 
 likelihood_transforms = [
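
Note: PATCH 100 above pins the iota bounds to [0, pi] and the dec bounds to [-pi/2, pi/2], matching the supports of the SinePrior and CosinePrior assigned to those parameters earlier in the series, and PATCH 101 below makes BoundToUnbound index its per-parameter bounds. For reference, a minimal sketch of the logit map BoundToUnbound applies and its sigmoid inverse (an illustration of the formulas in the diff below, not the library code):

    import jax.numpy as jnp

    def bound_to_unbound(x, lower, upper):
        # Logit: send x in (lower, upper) to the whole real line.
        p = (x - lower) / (upper - lower)
        return jnp.log(p / (1.0 - p))

    def unbound_to_bound(y, lower, upper):
        # Sigmoid: squash the real line back into (lower, upper).
        return lower + (upper - lower) / (1.0 + jnp.exp(-y))

    # Round trip for iota on its SinePrior support [0, pi]:
    y = bound_to_unbound(0.3, 0.0, jnp.pi)
    assert jnp.isclose(unbound_to_bound(y, 0.0, jnp.pi), 0.3)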
From e9288c8116b2209cdc1de477bac5fc0742cc9482 Mon Sep 17 00:00:00 2001
From: Thomas Ng
Date: Thu, 1 Aug 2024 14:43:14 -0400
Subject: [PATCH 101/104] Fix BoundToUnbound transform

---
 src/jimgw/jim.py        | 11 +++--------
 src/jimgw/transforms.py |  6 +++---
 2 files changed, 6 insertions(+), 11 deletions(-)

diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py
index f21aff42..4ccf0451 100644
--- a/src/jimgw/jim.py
+++ b/src/jimgw/jim.py
@@ -100,21 +100,16 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict):
         prior = self.prior.log_prob(named_params) + transform_jacobian
         for transform in self.likelihood_transforms:
             named_params = transform.forward(named_params)
-        named_params = jax.tree.map(
-            lambda x: x[0], named_params
-        )  # This [0] should be consolidate
         return (
-            self.likelihood.evaluate(named_params, data) + prior[0]
-        )  # This prior [0] should be consolidate
+            self.likelihood.evaluate(named_params, data) + prior
+        )
 
     def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
         if initial_guess.size == 0:
             initial_guess_named = self.prior.sample(key, self.sampler.n_chains)
             for transform in self.sample_transforms:
                 initial_guess_named = jax.vmap(transform.forward)(initial_guess_named)
-            initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[
-                0
-            ]  # This [0] should be consolidate
+            initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
         self.sampler.sample(initial_guess, None)  # type: ignore
diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py
index 1bb17c18..c6c4332f 100644
--- a/src/jimgw/transforms.py
+++ b/src/jimgw/transforms.py
@@ -338,13 +338,13 @@ def logit(x):
 
         self.transform_func = lambda x: {
             name_mapping[1][i]: logit(
-                (x[name_mapping[0][i]] - self.original_lower_bound)
-                / (self.original_upper_bound - self.original_lower_bound)
+                (x[name_mapping[0][i]] - self.original_lower_bound[i])
+                / (self.original_upper_bound[i] - self.original_lower_bound[i])
             )
             for i in range(len(name_mapping[0]))
         }
         self.inverse_transform_func = lambda x: {
-            name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound)
+            name_mapping[0][i]: (self.original_upper_bound[i] - self.original_lower_bound[i])
             / (1 + jnp.exp(-x[name_mapping[1][i]]))
             + self.original_lower_bound[i]
             for i in range(len(name_mapping[1]))

From fe500a792a04a4409874b0564392643c6678033b Mon Sep 17 00:00:00 2001
From: Thomas Ng
Date: Thu, 1 Aug 2024 15:19:06 -0400
Subject: [PATCH 102/104] Use ifos list

---
 test/integration/test_GW150914.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py
index 6193032a..9adda574 100644
--- a/test/integration/test_GW150914.py
+++ b/test/integration/test_GW150914.py
@@ -28,10 +28,10 @@
 fmin = 20.0
 fmax = 1024.0
 
-ifos = ["H1", "L1"]
+ifos = [H1, L1]
 
-H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
-L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
+for ifo in ifos:
+    ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
 
 Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"])
 q_prior = UniformPrior(
     0.125,
     1.,
     parameter_names=["q"],
 )
 s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"])
@@ -85,7 +85,7 @@
 likelihood = TransientLikelihoodFD(
-    [H1, L1],
+    ifos,
     waveform=RippleIMRPhenomD(),
     trigger_time=gps,
     duration=4,
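
Note: PATCH 103 below rewrites print_summary and get_samples so that flowMC chains are walked back through the sample transforms, returning samples in the physical parameter space instead of the unbounded sampling space. The repeated per-sample loops it introduces are equivalent to this condensed sketch (an illustration; add_name pairs a flat vector with the parameter names, and the pull-back method is called inverse in PATCH 103 before being renamed to backward in PATCH 104):

    def chains_to_named_samples(samples, parameter_names, sample_transforms):
        # samples: iterable of flat parameter vectors in the sampling frame.
        output = {}
        for row in samples:
            named = dict(zip(parameter_names, row))   # add_name equivalent
            for transform in sample_transforms:
                named = transform.backward(named)     # pull back one transform
            for key, value in named.items():
                output.setdefault(key, []).append(value)
        return output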
From 0d28520440f95697a36fa166841e83cd3e4b18aa Mon Sep 17 00:00:00 2001
From: Thomas Ng
Date: Thu, 1 Aug 2024 17:13:01 -0400
Subject: [PATCH 103/104] Fix jim summary and get_samples

---
 src/jimgw/jim.py | 64 +++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 53 insertions(+), 11 deletions(-)

diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py
index 4ccf0451..91a358a6 100644
--- a/src/jimgw/jim.py
+++ b/src/jimgw/jim.py
@@ -137,7 +137,7 @@ def negative_posterior(x: Float[Array, " n_dim"]):
         best_fit = optimizer.get_result()[0]
         return best_fit
 
-    def print_summary(self, transform: bool = True):
+    def print_summary(self):
         """
         Generate summary of the run
 
         """
 
         train_summary = self.sampler.get_sampler_state(training=True)
         production_summary = self.sampler.get_sampler_state(training=False)
 
-        training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T
-        training_chain = self.prior.add_name(training_chain)
-        if transform:
-            training_chain = self.prior.transform(training_chain)
+        training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names)).T
+        if self.sample_transforms:
+            transformed_chain = {}
+            named_sample = self.add_name(training_chain[0])
+            for transform in self.sample_transforms:
+                named_sample = transform.inverse(named_sample)
+            for key, value in named_sample.items():
+                transformed_chain[key] = [value]
+            for sample in training_chain[1:]:
+                named_sample = self.add_name(sample)
+                for transform in self.sample_transforms:
+                    named_sample = transform.inverse(named_sample)
+                for key, value in named_sample.items():
+                    transformed_chain[key].append(value)
+            training_chain = transformed_chain
+        else:
+            training_chain = self.add_name(training_chain)
         training_log_prob = train_summary["log_prob"]
         training_local_acceptance = train_summary["local_accs"]
         training_global_acceptance = train_summary["global_accs"]
         training_loss = train_summary["loss_vals"]
 
-        production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T
-        production_chain = self.prior.add_name(production_chain)
-        if transform:
-            production_chain = self.prior.transform(production_chain)
+        production_chain = production_summary["chains"].reshape(-1, len(self.parameter_names)).T
+        if self.sample_transforms:
+            transformed_chain = {}
+            named_sample = self.add_name(production_chain[0])
+            for transform in self.sample_transforms:
+                named_sample = transform.inverse(named_sample)
+            for key, value in named_sample.items():
+                transformed_chain[key] = [value]
+            for sample in production_chain[1:]:
+                named_sample = self.add_name(sample)
+                for transform in self.sample_transforms:
+                    named_sample = transform.inverse(named_sample)
+                for key, value in named_sample.items():
+                    transformed_chain[key].append(value)
+            production_chain = transformed_chain
+        else:
+            production_chain = self.add_name(production_chain)
         production_log_prob = production_summary["log_prob"]
         production_local_acceptance = production_summary["local_accs"]
         production_global_acceptance = production_summary["global_accs"]
@@ -214,8 +240,24 @@ def get_samples(self, training: bool = False) -> dict:
         """
 
         if training:
            chains = self.sampler.get_sampler_state(training=True)["chains"]
         else:
             chains = self.sampler.get_sampler_state(training=False)["chains"]
-        chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1)))
-        return chains
+        # Need rewrite to output chains instead of flattened samples
+        chains = chains.reshape(-1, len(self.parameter_names)).T
+        if self.sample_transforms:
+            transformed_chain = {}
+            named_sample = self.add_name(chains[0])
+            for transform in self.sample_transforms:
+                named_sample = transform.inverse(named_sample)
+            for key, value in named_sample.items():
+                transformed_chain[key] = [value]
+            for sample in chains[1:]:
+                named_sample = self.add_name(sample)
+                for transform in self.sample_transforms:
+                    named_sample = transform.inverse(named_sample)
+                for key, value in named_sample.items():
+                    transformed_chain[key].append(value)
+            return transformed_chain
+        else:
+            return self.add_name(chains)
 
     def plot(self):
         pass
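
Note: PATCH 104 below settles the BijectiveTransform API that the loops above rely on: inverse returns the pulled-back parameters together with the log Jacobian determinant, while backward returns the parameters alone, and get_samples now converts its output lists into jnp arrays. The intended calling convention, assuming t is any BijectiveTransform:

    x, log_det = t.inverse(y)   # parameters plus log|det J| of the pull-back
    x = t.backward(y)           # parameters only, for sample post-processing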
From f7e3fe882c44914eda36b12fe6a633d153dd4c0a Mon Sep 17 00:00:00 2001
From: Thomas Ng
Date: Thu, 1 Aug 2024 17:30:15 -0400
Subject: [PATCH 104/104] Fix jim output functions

---
 src/jimgw/jim.py        | 24 ++++++++++++++----------
 src/jimgw/transforms.py |  8 ++++----
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py
index 91a358a6..74f65efc 100644
--- a/src/jimgw/jim.py
+++ b/src/jimgw/jim.py
@@ -151,13 +151,13 @@ def print_summary(self):
             transformed_chain = {}
             named_sample = self.add_name(training_chain[0])
             for transform in self.sample_transforms:
-                named_sample = transform.inverse(named_sample)
+                named_sample = transform.backward(named_sample)
             for key, value in named_sample.items():
                 transformed_chain[key] = [value]
             for sample in training_chain[1:]:
                 named_sample = self.add_name(sample)
                 for transform in self.sample_transforms:
-                    named_sample = transform.inverse(named_sample)
+                    named_sample = transform.backward(named_sample)
                 for key, value in named_sample.items():
                     transformed_chain[key].append(value)
             training_chain = transformed_chain
@@ -173,13 +173,13 @@ def print_summary(self):
             transformed_chain = {}
             named_sample = self.add_name(production_chain[0])
             for transform in self.sample_transforms:
-                named_sample = transform.inverse(named_sample)
+                named_sample = transform.backward(named_sample)
             for key, value in named_sample.items():
                 transformed_chain[key] = [value]
             for sample in production_chain[1:]:
                 named_sample = self.add_name(sample)
                 for transform in self.sample_transforms:
-                    named_sample = transform.inverse(named_sample)
+                    named_sample = transform.backward(named_sample)
                 for key, value in named_sample.items():
                     transformed_chain[key].append(value)
             production_chain = transformed_chain
@@ -192,7 +192,7 @@ def print_summary(self):
         print("Training summary")
         print("=" * 10)
         for key, value in training_chain.items():
-            print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}")
+            print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}")
         print(
             f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}"
         )
@@ -209,7 +209,7 @@ def print_summary(self):
         print("Production summary")
         print("=" * 10)
         for key, value in production_chain.items():
-            print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}")
+            print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}")
         print(
             f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}"
         )
@@ -246,18 +246,22 @@ def get_samples(self, training: bool = False) -> dict:
             transformed_chain = {}
             named_sample = self.add_name(chains[0])
             for transform in self.sample_transforms:
-                named_sample = transform.inverse(named_sample)
+                named_sample = transform.backward(named_sample)
             for key, value in named_sample.items():
                 transformed_chain[key] = [value]
             for sample in chains[1:]:
                 named_sample = self.add_name(sample)
                 for transform in self.sample_transforms:
-                    named_sample = transform.inverse(named_sample)
+                    named_sample = transform.backward(named_sample)
                 for key, value in named_sample.items():
                     transformed_chain[key].append(value)
-            return transformed_chain
+            output = transformed_chain
         else:
-            return self.add_name(chains)
+            output = self.add_name(chains)
+
+        for key in output.keys():
+            output[key] = jnp.array(output[key])
+        return output
 
     def plot(self):
         pass
diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py
index c6c4332f..9aea6503 100644
--- a/src/jimgw/transforms.py
+++ b/src/jimgw/transforms.py
@@ -115,7 +115,7 @@ class BijectiveTransform(NtoNTransform):
 
     inverse_transform_func: Callable[[dict[str, Float]], dict[str, Float]]
 
-    def inverse(self, y: dict[str, Float]) -> dict[str, Float]:
+    def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]:
         """
         Inverse transform the input y to original coordinate x.
 
         Parameters
         ----------
         y : dict[str, Float]
             The transformed dictionary.
 
         Returns
         -------
         x : dict[str, Float]
             The original dictionary.
+        log_det : Float
+            The log Jacobian determinant.
""" y_copy = y.copy() transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) @@ -145,7 +147,7 @@ def inverse(self, y: dict[str, Float]) -> dict[str, Float]: ) return y_copy, jacobian - def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + def backward(self, y: dict[str, Float]) -> dict[str, Float]: """ Pull back the input y to original coordinate x and return the log Jacobian determinant. @@ -158,8 +160,6 @@ def backward(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: ------- x : dict[str, Float] The original dictionary. - log_det : Float - The log Jacobian determinant. """ y_copy = y.copy() output_params = self.inverse_transform_func(y_copy)