Commit

Merge pull request #5 from kazewong/98-moving-naming-tracking-into-jim-class-from-prior-class

Sync
thomasckng authored Jul 31, 2024
2 parents 39126f5 + 46bd044 commit a0161ee
Showing 3 changed files with 30 additions and 35 deletions.
35 changes: 19 additions & 16 deletions src/jimgw/prior.py
@@ -7,14 +7,13 @@
from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped

from jimgw.transforms import (
Transform,
NtoNTransform,
BijectiveTransform,
LogitTransform,
ScaleTransform,
OffsetTransform,
ArcSineTransform,
PowerLawTransform,
ParetoTransform,
# PowerLawTransform,
# ParetoTransform,
)


@@ -61,7 +60,7 @@ def sample(
) -> dict[str, Float[Array, " n_samples"]]:
raise NotImplementedError

def log_prob(self, x: dict[str, Array]) -> Float:
def log_prob(self, z: dict[str, Array]) -> Float:
raise NotImplementedError


@@ -99,7 +98,7 @@ def sample(
samples = jnp.log(samples / (1 - samples))
return self.add_name(samples[None])

def log_prob(self, x: dict[str, Float]) -> Float:
def log_prob(self, z: dict[str, Float]) -> Float:
variable = x[self.parameter_names[0]]
return -variable - 2 * jnp.log(1 + jnp.exp(-variable))

@@ -139,26 +138,28 @@ def sample(
samples = jax.random.normal(rng_key, (n_samples,))
return self.add_name(samples[None])

def log_prob(self, x: dict[str, Float]) -> Float:
def log_prob(self, z: dict[str, Float]) -> Float:
variable = x[self.parameter_names[0]]
return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi)


class SequentialTransformPrior(Prior):
"""
Transform a prior distribution by applying a sequence of transforms.
The space before the transform is named as x,
and the space after the transform is named as z
"""

base_prior: Prior
transforms: list[NtoNTransform]
transforms: list[BijectiveTransform]

def __repr__(self):
return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})"

def __init__(
self,
base_prior: Prior,
transforms: list[NtoNTransform],
transforms: list[BijectiveTransform],
):

self.base_prior = base_prior
@@ -174,14 +175,16 @@ def sample(
output = self.base_prior.sample(rng_key, n_samples)
return jax.vmap(self.transform)(output)

def log_prob(self, x: dict[str, Float]) -> Float:
def log_prob(self, z: dict[str, Float]) -> Float:
"""
log_prob has to be evaluated in the space of the base_prior.
Evaluating the probability of the transformed variable z.
This is what flowMC should sample from
"""
output = self.base_prior.log_prob(x)
for transform in self.transforms:
x, log_jacobian = transform.transform(x)
output = 0
for transform in reversed(self.transforms):
z, log_jacobian = transform.inverse(z)
output -= log_jacobian
output += self.base_prior.log_prob(z)
return output

def transform(self, x: dict[str, Float]) -> dict[str, Float]:
@@ -223,10 +226,10 @@ def sample(
output.update(prior.sample(subkey, n_samples))
return output

def log_prob(self, x: dict[str, Float]) -> Float:
def log_prob(self, z: dict[str, Float]) -> Float:
output = 0.0
for prior in self.base_prior:
output += prior.log_prob(x)
output += prior.log_prob(z)
return output


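For reference, the change-of-variables logic in the new SequentialTransformPrior.log_prob can be illustrated with a self-contained sketch. This is not the jimgw API: ToyScaleTransform is a hypothetical stand-in for the BijectiveTransform objects used above, and the sketch assumes, as a labelled convention, that inverse returns the log-Jacobian of the forward map. Only the evaluation pattern (invert each transform from last to first, subtract its log-Jacobian, then evaluate the base prior in its own space) mirrors the diff.

import jax.numpy as jnp

# Hypothetical stand-in for a bijective transform: z = 2 * x + 1 on parameter "x".
class ToyScaleTransform:
    def inverse(self, z):
        # Map the transformed value back to the base space and return the
        # log-Jacobian of the forward map, log |dz/dx| = log 2 (assumed convention).
        x = {"x": (z["x"] - 1.0) / 2.0}
        return x, jnp.log(2.0)

def toy_sequential_log_prob(z, transforms, base_log_prob):
    # Same pattern as the diff: walk the transforms in reverse, undo each one,
    # subtract its log-Jacobian, then evaluate the base prior in its own space.
    output = 0.0
    for transform in reversed(transforms):
        z, log_jacobian = transform.inverse(z)
        output -= log_jacobian
    return output + base_log_prob(z)

# Standard-normal base prior on "x"; at z = 1 (i.e. x = 0) the result should be
# the standard-normal log-density at 0 minus log 2.
base_log_prob = lambda x: -0.5 * x["x"] ** 2 - 0.5 * jnp.log(2 * jnp.pi)
print(toy_sequential_log_prob({"x": jnp.array(1.0)}, [ToyScaleTransform()], base_log_prob))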
4 changes: 2 additions & 2 deletions src/jimgw/transforms.py
@@ -229,8 +229,8 @@ def __init__(
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
self.transform_func = lambda x: jnp.arcsin(x)
self.inverse_transform_func = lambda x: jnp.sin(x)
self.transform_func = lambda x: [jnp.arcsin(x[0])]
self.inverse_transform_func = lambda x: [jnp.sin(x[0])]


# class PowerLawTransform(UnivariateTransform):
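The ArcSineTransform change above switches the forward and inverse functions from scalar-to-scalar callables to list-in/list-out callables. A minimal standalone round-trip check of that list-valued convention (plain lambdas, not the jimgw transform classes) might look like:

import jax.numpy as jnp

# List-in / list-out callables, matching the convention in the diff above.
transform_func = lambda x: [jnp.arcsin(x[0])]
inverse_transform_func = lambda x: [jnp.sin(x[0])]

value = [jnp.array(0.5)]
roundtrip = inverse_transform_func(transform_func(value))
assert jnp.allclose(roundtrip[0], value[0])  # sin(arcsin(0.5)) == 0.5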
26 changes: 9 additions & 17 deletions test/integration/test_GW150914.py
@@ -4,7 +4,7 @@
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import CombinePrior, UniformPrior
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
@@ -33,43 +33,35 @@
L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)

Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"])
q_prior = UniformPrior(
eta_prior = UniformPrior(
0.125,
1.0,
parameter_names=["q"], # Need name transformation in likelihood to work
0.25,
parameter_names=["eta"], # Need name transformation in likelihood to work
)
s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"])
s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"])
# Current likelihood sampling will fail and give nan because of large number
dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"])
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
cos_iota_prior = UniformPrior(
-1.0,
1.0,
parameter_names=["cos_iota"], # Need name transformation in likelihood to work
)
iota_prior = CosinePrior(parameter_names=["iota"])
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"])
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"])
sin_dec_prior = UniformPrior(
-1.0,
1.0,
parameter_names=["sin_dec"], # Need name transformation in likelihood to work
)
dec_prior = SinePrior(parameter_names=["dec"])

prior = CombinePrior(
[
Mc_prior,
q_prior,
eta_prior,
s1z_prior,
s2z_prior,
dL_prior,
t_c_prior,
phase_c_prior,
cos_iota_prior,
iota_prior,
psi_prior,
ra_prior,
sin_dec_prior,
dec_prior,
]
)
likelihood = TransientLikelihoodFD(
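The test now places priors directly on the physical parameters (eta, iota, dec) via UniformPrior, CosinePrior, and SinePrior, instead of uniform priors on reparameterized quantities. A minimal usage sketch of the CombinePrior interface, using only constructors and signatures visible in this diff (the two-parameter prior below is illustrative, not the full GW150914 setup):

import jax
from jimgw.prior import CombinePrior, UniformPrior

# Illustrative two-parameter prior; constructor arguments follow the diff above.
prior = CombinePrior(
    [
        UniformPrior(10.0, 80.0, parameter_names=["M_c"]),
        UniformPrior(0.125, 0.25, parameter_names=["eta"]),
    ]
)

key = jax.random.PRNGKey(0)
samples = prior.sample(key, 5)             # dict of length-5 arrays keyed by parameter name
log_p = jax.vmap(prior.log_prob)(samples)  # per-sample log prior, evaluated in the sampled space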
