Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reparameterization of extrinsic parameter for better sampling efficiency #131

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
f33a782
Adding transform from geocentric arrival time to detector arrival time
tsunhopang Aug 12, 2024
3505394
Adding transform from distance to SNR weighted distance
tsunhopang Aug 12, 2024
df75ceb
updating the typing for object attributes
tsunhopang Aug 13, 2024
9f2f52b
Adding geocentric phase to detector phase
tsunhopang Aug 13, 2024
b62970f
Adding ZeroLikelihood for testing purpose
tsunhopang Aug 13, 2024
4ea3322
Adding the missing mode 2pi for phasing transform
tsunhopang Aug 13, 2024
7a4bae0
Test wip
tsunhopang Aug 13, 2024
d5f86e5
Phase renaming
tsunhopang Aug 13, 2024
0a2e68c
wip
Aug 13, 2024
b96512c
Push conditional bijective transform
kazewong Aug 14, 2024
526e33c
Switch to using conditional transform
Aug 14, 2024
dbf3f30
Switch to using conditional transform
Aug 14, 2024
a375361
Fixing jacobian handling
Aug 14, 2024
d79af97
Both arrival phase and time transform are fully vectorized
Aug 16, 2024
bcbcbe2
Shifting distance transform to conditional
Aug 16, 2024
8dab27b
update example
Aug 16, 2024
fd33882
Fixing the single sided unbound transform
Aug 16, 2024
03e76dc
Update extrinsic test
Aug 17, 2024
8fe4b5f
bugfix for single sided transform
tsunhopang Aug 19, 2024
a19b556
Update test
Aug 19, 2024
6d2cd97
update distance transform
tsunhopang Aug 19, 2024
6993dd9
Update test
Aug 19, 2024
ff65fcf
Update arrival time transform
tsunhopang Aug 19, 2024
b98d783
Update test
Aug 19, 2024
e399a5e
Fix typo
tsunhopang Aug 19, 2024
583b759
Fix typo
tsunhopang Aug 19, 2024
b1133d3
Update runManager.py
xuyuon Aug 22, 2024
2a9d696
Update runManager.py
xuyuon Aug 22, 2024
cd559b6
Adding docstring for zerolikelihood
tsunhopang Aug 22, 2024
0c693d3
Merge pull request #9 from tsunhopang/extrinsic_parameter_sampling_im…
xuyuon Aug 23, 2024
5d6a795
Added run script
xuyuon Aug 25, 2024
a844170
Added run script
xuyuon Aug 25, 2024
bfa4e47
Added run script
xuyuon Aug 26, 2024
32b8e2e
Added run script
xuyuon Aug 26, 2024
041cf2a
Added run script
xuyuon Aug 26, 2024
0803ae8
Added run script on gw200112
xuyuon Aug 26, 2024
d932183
Added functions to calculate iota
xuyuon Aug 26, 2024
7b038ac
Updated transforms.py
xuyuon Aug 26, 2024
c2f0c79
Updated transforms.py
xuyuon Aug 26, 2024
1273246
Updated GW150914_Pv2.py
xuyuon Aug 26, 2024
0e181d8
Updated GW150914_Pv2_reparam.py
xuyuon Aug 26, 2024
79bdb1f
Updated GW150914_Pv2_reparam.py
xuyuon Aug 26, 2024
65a15c4
Updated transforms.py
xuyuon Aug 26, 2024
6a4fcaa
Updated utils.py
xuyuon Aug 26, 2024
2c440a3
Updated jim.py
xuyuon Aug 26, 2024
39c5364
Added run script
xuyuon Aug 26, 2024
7457c56
Fixing phase inverse transformation
tsunhopang Aug 26, 2024
ddfafc1
Added run script
xuyuon Aug 26, 2024
6291682
Merge branch 'jim-dev' into extrinsic_parameter_sampling_improvement
tsunhopang Sep 2, 2024
86605ea
Remove duplicated import
tsunhopang Sep 2, 2024
2fbfc04
Adding the named_Mc_q_to_m1_m2 back
tsunhopang Sep 2, 2024
ce7b308
Hard-code transform name_mapping and conditional_parameters
tsunhopang Sep 2, 2024
f80a87c
Merge branch 'extrinsic-parameter-sampling-improvement' into extrinsi…
xuyuon Sep 3, 2024
c4995d7
setting phiRef to 0 for getting the iota
tsunhopang Sep 5, 2024
762b7e0
Remove if-else function
tsunhopang Sep 5, 2024
8805205
Revert "Remove if-else function"
tsunhopang Sep 5, 2024
5ce686c
Improve the precession handling for extrinsic transform
tsunhopang Sep 5, 2024
de8ca6f
Update GW150914_D script
kazewong Sep 5, 2024
ce91d2c
update xample
kazewong Sep 10, 2024
94ea9fd
fix typing mistake
kazewong Sep 10, 2024
641d83f
Merge branch 'jim-dev' into extrinsic_parameter_sampling_improvement
tsunhopang Sep 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions example/GW150914_D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)

###########################################
########## First we grab data #############
###########################################

# first, fetch a 4s segment centered on GW150914
gps = 1126259462.4
duration = 4
post_trigger_duration = 2
start_pad = duration - post_trigger_duration
end_pad = post_trigger_duration
fmin = 20.0
fmax = 1024.0

ifos = [H1, L1]

for ifo in ifos:
ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)

M_c_min, M_c_max = 10.0, 80.0
q_min, q_max = 0.125, 1.0
m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"])
m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"])
s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"])
s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"])
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
iota_prior = SinePrior(parameter_names=["iota"])
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"])
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"])
dec_prior = CosinePrior(parameter_names=["dec"])

prior = CombinePrior(
[
m_1_prior,
m_2_prior,
s1z_prior,
s2z_prior,
dL_prior,
t_c_prior,
phase_c_prior,
iota_prior,
psi_prior,
ra_prior,
dec_prior,
]
)

sample_transforms = [
ComponentMassesToChirpMassMassRatioTransform,
BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = (["s1_z"], ["s1_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = (["s2_z"], ["s2_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = (["d_L"], ["d_L_unbounded"]) , original_lower_bound=1.0, original_upper_bound=2000.0),
BoundToUnbound(name_mapping = (["t_c"], ["t_c_unbounded"]) , original_lower_bound=-0.05, original_upper_bound=0.05),
BoundToUnbound(name_mapping = (["phase_c"], ["phase_c_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]), original_lower_bound=0., original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]

likelihood_transforms = [
ComponentMassesToChirpMassSymmetricMassRatioTransform,
]

likelihood = TransientLikelihoodFD(
ifos,
waveform=RippleIMRPhenomD(),
trigger_time=gps,
duration=4,
post_trigger_duration=2,
)


mass_matrix = jnp.eye(11)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)

n_epochs = 30
n_loop_training = 20
learning_rate = 1e-4


jim = Jim(
likelihood,
prior,
sample_transforms=sample_transforms,
likelihood_transforms=likelihood_transforms,
n_loop_training=n_loop_training,
n_loop_production=20,
n_local_steps=10,
n_global_steps=1000,
n_chains=500,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30000,
n_flow_samples=100000,
momentum=0.9,
batch_size=30000,
use_global=True,
train_thinning=1,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
strategies=[Adam_optimizer, "default"],
verbose=True
)

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()
161 changes: 161 additions & 0 deletions example/GW150914_D_reparam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)

###########################################
########## First we grab data #############
###########################################

# first, fetch a 4s segment centered on GW150914
gps = 1126259462.4
duration = 4
post_trigger_duration = 2
start_pad = duration - post_trigger_duration
end_pad = post_trigger_duration
fmin = 20.0
fmax = 1024.0

ifos = [H1, L1]

for ifo in ifos:
ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)

M_c_min, M_c_max = 10.0, 80.0
q_min, q_max = 0.125, 1.0
m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"])
m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"])
s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"])
s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"])
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
iota_prior = SinePrior(parameter_names=["iota"])
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"])
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"])
dec_prior = CosinePrior(parameter_names=["dec"])

prior = CombinePrior(
[
m_1_prior,
m_2_prior,
s1z_prior,
s2z_prior,
dL_prior,
t_c_prior,
phase_c_prior,
iota_prior,
psi_prior,
ra_prior,
dec_prior,
]
)

sample_transforms = [
# all the user reparametrization transform
ComponentMassesToChirpMassMassRatioTransform,
DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
# all the bound to unbound transform
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]

likelihood_transforms = [
ComponentMassesToChirpMassSymmetricMassRatioTransform,
]

likelihood = TransientLikelihoodFD(
ifos,
waveform=RippleIMRPhenomD(),
trigger_time=gps,
duration=4,
post_trigger_duration=2,
)


mass_matrix = jnp.eye(11)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)

n_epochs = 30
n_loop_training = 100
learning_rate = 1e-4


jim = Jim(
likelihood,
prior,
sample_transforms=sample_transforms,
likelihood_transforms=likelihood_transforms,
n_loop_training=n_loop_training,
n_loop_production=20,
n_local_steps=10,
n_global_steps=1000,
n_chains=500,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30000,
n_flow_samples=100000,
momentum=0.9,
batch_size=30000,
use_global=True,
train_thinning=1,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
strategies=[Adam_optimizer, "default"],
)

jim.sample(jax.random.PRNGKey(42))
#jim.get_samples()
#jim.print_summary()


###########################################
########## Visualize the Data #############
###########################################
import corner
import matplotlib.pyplot as plt
import numpy as np

production_summary = jim.sampler.get_sampler_state(training=False)
production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T
if jim.sample_transforms:
transformed_chain = jim.add_name(production_chain)
for transform in reversed(jim.sample_transforms):
transformed_chain = transform.backward(transformed_chain)
result = transformed_chain
labels = list(transformed_chain.keys())

samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array
transposed_array = samples.T # transpose the array
figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True)
plt.savefig("GW1500914_D_reparam.jpeg")

############################################
############## Save the Run ################
############################################
#import pickle
#pickle.dump(result, open("GW150914_D_reparam.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL)
Loading
Loading