Skip to content

Commit

Permalink
Merge pull request #123 from thomasckng/transform
Browse files Browse the repository at this point in the history
Use uniform in component mass in test
  • Loading branch information
kazewong authored Aug 2, 2024
2 parents 87440db + 15e4074 commit e1800da
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 35 deletions.
22 changes: 15 additions & 7 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,21 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict):
named_params = transform.forward(named_params)
return self.likelihood.evaluate(named_params, data) + prior

def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess.size == 0:
initial_guess_named = self.prior.sample(key, self.sampler.n_chains)
for transform in self.sample_transforms:
initial_guess_named = jax.vmap(transform.forward)(initial_guess_named)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
self.sampler.sample(initial_guess, None) # type: ignore
def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
if initial_position.size == 0:
initial_guess = []
for _ in range(self.sampler.n_chains):
flag = True
while flag:
key = jax.random.split(key)[1]
guess = self.prior.sample(key, 1)
for transform in self.sample_transforms:
guess = transform.forward(guess)
guess = jnp.array([i for i in guess.values()]).T[0]
flag = not jnp.all(jnp.isfinite(guess))
initial_guess.append(guess)
initial_position = jnp.array(initial_guess)
self.sampler.sample(initial_position, None) # type: ignore

def maximize_likelihood(
self,
Expand Down
66 changes: 52 additions & 14 deletions src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from jimgw.single_event.detector import Detector
from jimgw.utils import log_i0
from jimgw.single_event.waveform import Waveform
from jimgw.transforms import BijectiveTransform, NtoMTransform


class SingleEventLiklihood(LikelihoodBase):
Expand Down Expand Up @@ -184,8 +185,6 @@ def __init__(
self,
detectors: list[Detector],
waveform: Waveform,
prior: Prior,
bounds: Float[Array, " n_dim 2"],
n_bins: int = 100,
trigger_time: float = 0,
duration: float = 4,
Expand All @@ -194,6 +193,9 @@ def __init__(
n_steps: int = 2000,
ref_params: dict = {},
reference_waveform: Optional[Waveform] = None,
prior: Optional[Prior] = None,
sample_transforms: list[BijectiveTransform] = [],
likelihood_transforms: list[NtoMTransform] = [],
**kwargs,
) -> None:
super().__init__(
Expand Down Expand Up @@ -254,17 +256,24 @@ def __init__(
)
self.freq_grid_low = freq_grid[:-1]

if not ref_params:
if ref_params:
self.ref_params = ref_params
print(f"Reference parameters provided, which are {self.ref_params}")
elif prior:
print("No reference parameters are provided, finding it...")
ref_params = self.maximize_likelihood(
bounds=bounds, prior=prior, popsize=popsize, n_steps=n_steps
prior=prior,
sample_transforms=sample_transforms,
likelihood_transforms=likelihood_transforms,
popsize=popsize,
n_steps=n_steps,
)
self.ref_params = {key: float(value) for key, value in ref_params.items()}
print(f"The reference parameters are {self.ref_params}")
else:
self.ref_params = ref_params
print(f"Reference parameters provided, which are {self.ref_params}")

raise ValueError(
"Either reference parameters or parameter names must be provided"
)
# safe guard for the reference parameters
# since ripple cannot handle eta=0.25
if jnp.isclose(self.ref_params["eta"], 0.25):
Expand Down Expand Up @@ -542,25 +551,54 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center):
def maximize_likelihood(
self,
prior: Prior,
bounds: Float[Array, " n_dim 2"],
likelihood_transforms: list[NtoMTransform],
sample_transforms: list[BijectiveTransform],
popsize: int = 100,
n_steps: int = 2000,
):
parameter_names = prior.parameter_names
for transform in sample_transforms:
parameter_names = transform.propagate_name(parameter_names)

def y(x):
return -self.evaluate_original(prior.transform(prior.add_name(x)), {})
named_params = dict(zip(parameter_names, x))
for transform in reversed(sample_transforms):
named_params = transform.backward(named_params)
for transform in likelihood_transforms:
named_params = transform.forward(named_params)
return -self.evaluate_original(named_params, {})

print("Starting the optimizer")

optimizer = optimization_Adam(
n_steps=n_steps, learning_rate=0.001, noise_level=1
)
initial_position = jnp.array(
list(prior.sample(jax.random.PRNGKey(0), popsize).values())
).T

key = jax.random.PRNGKey(0)
initial_position = []
for _ in range(popsize):
flag = True
while flag:
key = jax.random.split(key)[1]
guess = prior.sample(key, 1)
for transform in sample_transforms:
guess = transform.forward(guess)
guess = jnp.array([i for i in guess.values()]).T[0]
flag = not jnp.all(jnp.isfinite(guess))
initial_position.append(guess)
initial_position = jnp.array(initial_position)
rng_key, optimized_positions, summary = optimizer.optimize(
jax.random.PRNGKey(12094), y, initial_position
)
best_fit = optimized_positions[jnp.nanargmin(summary["final_log_prob"])]
return prior.transform(prior.add_name(best_fit))

best_fit = optimized_positions[jnp.argmin(summary["final_log_prob"])]

named_params = dict(zip(parameter_names, best_fit))
for transform in reversed(sample_transforms):
named_params = transform.backward(named_params)
for transform in likelihood_transforms:
named_params = transform.forward(named_params)
return named_params


likelihood_presets = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import time

import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam

Expand All @@ -19,8 +17,6 @@
########## First we grab data #############
###########################################

total_time_start = time.time()

# first, fetch a 4s segment centered on GW150914
gps = 1126259462.4
duration = 4
Expand All @@ -35,11 +31,13 @@
for ifo in ifos:
ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)

Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"])
q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"])
M_c_min, M_c_max = 10.0, 80.0
q_min, q_max = 0.125, 1.0
m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"])
m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"])
s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"])
s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"])
dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"])
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
iota_prior = SinePrior(parameter_names=["iota"])
Expand All @@ -49,8 +47,8 @@

prior = CombinePrior(
[
Mc_prior,
q_prior,
m_1_prior,
m_2_prior,
s1z_prior,
s2z_prior,
dL_prior,
Expand All @@ -64,8 +62,9 @@
)

sample_transforms = [
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0),
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.),
ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]),
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0),
Expand All @@ -79,7 +78,7 @@
]

likelihood_transforms = [
MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]),
ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]),
]

likelihood = TransientLikelihoodFD(
Expand Down
133 changes: 133 additions & 0 deletions test/integration/test_GW150914_D_heterodyne.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)

###########################################
########## First we grab data #############
###########################################

# first, fetch a 4s segment centered on GW150914
gps = 1126259462.4
duration = 4
post_trigger_duration = 2
start_pad = duration - post_trigger_duration
end_pad = post_trigger_duration
fmin = 20.0
fmax = 1024.0

ifos = [H1, L1]

for ifo in ifos:
ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)

M_c_min, M_c_max = 10.0, 80.0
q_min, q_max = 0.125, 1.0
m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"])
m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"])
s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"])
s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"])
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
iota_prior = SinePrior(parameter_names=["iota"])
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"])
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"])
dec_prior = CosinePrior(parameter_names=["dec"])

prior = CombinePrior(
[
m_1_prior,
m_2_prior,
s1z_prior,
s2z_prior,
dL_prior,
t_c_prior,
phase_c_prior,
iota_prior,
psi_prior,
ra_prior,
dec_prior,
]
)

sample_transforms = [
ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]),
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0),
BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05),
BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]

likelihood_transforms = [
ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]),
]

likelihood = HeterodynedTransientLikelihoodFD(
ifos,
prior=prior,
waveform=RippleIMRPhenomD(),
trigger_time=gps,
duration=4,
post_trigger_duration=2,
sample_transforms=sample_transforms,
likelihood_transforms=likelihood_transforms,
n_steps=5,
popsize=10,
)


mass_matrix = jnp.eye(11)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}

Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1)

n_epochs = 2
n_loop_training = 1
learning_rate = 1e-4


jim = Jim(
likelihood,
prior,
sample_transforms=sample_transforms,
likelihood_transforms=likelihood_transforms,
n_loop_training=n_loop_training,
n_loop_production=1,
n_local_steps=5,
n_global_steps=5,
n_chains=4,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30,
n_flow_samples=100,
momentum=0.9,
batch_size=100,
use_global=True,
train_thinning=1,
output_thinning=1,
local_sampler_arg=local_sampler_arg,
strategies=[Adam_optimizer, "default"],
)

jim.sample(jax.random.PRNGKey(42))

0 comments on commit e1800da

Please sign in to comment.