Skip to content

Commit

Permalink
Merge pull request kazewong#163 from kazewong/Fix-spin-reparameteriza…
Browse files Browse the repository at this point in the history
…tion

Fix spin reparameterization
  • Loading branch information
kazewong authored Oct 13, 2024
2 parents f391bd7 + c5957b4 commit d0554fa
Show file tree
Hide file tree
Showing 10 changed files with 533 additions and 409 deletions.
84 changes: 69 additions & 15 deletions example/GW150914_IMRPhenomD.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,22 @@
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior
from jimgw.prior import (
CombinePrior,
UniformPrior,
CosinePrior,
SinePrior,
PowerLawPrior,
)
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform
from jimgw.single_event.transforms import (
ComponentMassesToChirpMassSymmetricMassRatioTransform,
SkyFrameToDetectorFrameSkyPositionTransform,
ComponentMassesToChirpMassMassRatioTransform,
)
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam

Expand Down Expand Up @@ -65,18 +75,62 @@

sample_transforms = [
# ComponentMassesToChirpMassMassRatioTransform,
BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = (["eta"], ["eta_unbounded"]), original_lower_bound=eta_min, original_upper_bound=eta_max),
BoundToUnbound(name_mapping = (["s1_z"], ["s1_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = (["s2_z"], ["s2_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = (["d_L"], ["d_L_unbounded"]) , original_lower_bound=1.0, original_upper_bound=2000.0),
BoundToUnbound(name_mapping = (["t_c"], ["t_c_unbounded"]) , original_lower_bound=-0.05, original_upper_bound=0.05),
BoundToUnbound(name_mapping = (["phase_c"], ["phase_c_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]), original_lower_bound=0., original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(
name_mapping=(["M_c"], ["M_c_unbounded"]),
original_lower_bound=M_c_min,
original_upper_bound=M_c_max,
),
BoundToUnbound(
name_mapping=(["eta"], ["eta_unbounded"]),
original_lower_bound=eta_min,
original_upper_bound=eta_max,
),
BoundToUnbound(
name_mapping=(["s1_z"], ["s1_z_unbounded"]),
original_lower_bound=-1.0,
original_upper_bound=1.0,
),
BoundToUnbound(
name_mapping=(["s2_z"], ["s2_z_unbounded"]),
original_lower_bound=-1.0,
original_upper_bound=1.0,
),
BoundToUnbound(
name_mapping=(["d_L"], ["d_L_unbounded"]),
original_lower_bound=1.0,
original_upper_bound=2000.0,
),
BoundToUnbound(
name_mapping=(["t_c"], ["t_c_unbounded"]),
original_lower_bound=-0.05,
original_upper_bound=0.05,
),
BoundToUnbound(
name_mapping=(["phase_c"], ["phase_c_unbounded"]),
original_lower_bound=0.0,
original_upper_bound=2 * jnp.pi,
),
BoundToUnbound(
name_mapping=(["iota"], ["iota_unbounded"]),
original_lower_bound=0.0,
original_upper_bound=jnp.pi,
),
BoundToUnbound(
name_mapping=(["psi"], ["psi_unbounded"]),
original_lower_bound=0.0,
original_upper_bound=jnp.pi,
),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(
name_mapping=(["zenith"], ["zenith_unbounded"]),
original_lower_bound=0.0,
original_upper_bound=jnp.pi,
),
BoundToUnbound(
name_mapping=(["azimuth"], ["azimuth_unbounded"]),
original_lower_bound=0.0,
original_upper_bound=2 * jnp.pi,
),
]

likelihood_transforms = [
Expand Down Expand Up @@ -125,9 +179,9 @@
output_thinning=10,
local_sampler_arg=local_sampler_arg,
strategies=[Adam_optimizer, "default"],
verbose=True
verbose=True,
)

jim.sample(jax.random.PRNGKey(42))
# jim.get_samples()
# jim.print_summary()
# jim.print_summary()
173 changes: 173 additions & 0 deletions example/GW150914_IMRPhenomPV2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import time

import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.jim import Jim
from jimgw.prior import (
CombinePrior,
UniformPrior,
CosinePrior,
SinePrior,
PowerLawPrior,
UniformSpherePrior,
)
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomPv2
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import (
SkyFrameToDetectorFrameSkyPositionTransform,
SphereSpinToCartesianSpinTransform,
MassRatioToSymmetricMassRatioTransform,
DistanceToSNRWeightedDistanceTransform,
GeocentricArrivalTimeToDetectorArrivalTimeTransform,
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform,
)
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)

###########################################
########## First we grab data #############
###########################################

total_time_start = time.time()

# first, fetch a 4s segment centered on GW150914
gps = 1126259462.4
start = gps - 2
end = gps + 2
fmin = 20.0
fmax = 1024.0

ifos = [H1, L1]

H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)

waveform = RippleIMRPhenomPv2(f_ref=20)

###########################################
########## Set up priors ##################
###########################################

prior = []

# Mass prior
M_c_min, M_c_max = 10.0, 80.0
q_min, q_max = 0.125, 1.0
Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"])
q_prior = UniformPrior(q_min, q_max, parameter_names=["q"])

prior = prior + [Mc_prior, q_prior]

# Spin prior
s1_prior = UniformSpherePrior(parameter_names=["s1"])
s2_prior = UniformSpherePrior(parameter_names=["s2"])
iota_prior = SinePrior(parameter_names=["iota"])

prior = prior + [
s1_prior,
s2_prior,
iota_prior,
]

# Extrinsic prior
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"])
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"])
dec_prior = CosinePrior(parameter_names=["dec"])

prior = prior + [
dL_prior,
t_c_prior,
phase_c_prior,
psi_prior,
ra_prior,
dec_prior,
]

prior = CombinePrior(prior)

# Defining Transforms

sample_transforms = [
DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]

likelihood_transforms = [
MassRatioToSymmetricMassRatioTransform,
SphereSpinToCartesianSpinTransform("s1"),
SphereSpinToCartesianSpinTransform("s2"),
]


likelihood = TransientLikelihoodFD(
[H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2
)


mass_matrix = jnp.eye(prior.n_dim)
# mass_matrix = mass_matrix.at[1, 1].set(1e-3)
# mass_matrix = mass_matrix.at[9, 9].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 1e-3}

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)

import optax

n_epochs = 20
n_loop_training = 100
total_epochs = n_epochs * n_loop_training
start = total_epochs // 10
learning_rate = optax.polynomial_schedule(
1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start
)

jim = Jim(
likelihood,
prior,
sample_transforms=sample_transforms,
likelihood_transforms=likelihood_transforms,
n_loop_training=n_loop_training,
n_loop_production=20,
n_local_steps=10,
n_global_steps=1000,
n_chains=500,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30000,
n_flow_sample=100000,
momentum=0.9,
batch_size=30000,
use_global=True,
keep_quantile=0.0,
train_thinning=1,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
# strategies=[Adam_optimizer,"default"],
)


jim.sample(jax.random.PRNGKey(42))
Loading

0 comments on commit d0554fa

Please sign in to comment.