1
+ import jax
2
+ import jax .numpy as jnp
3
+
4
+ from jimgw .jim import Jim
5
+ from jimgw .prior import CombinePrior , UniformPrior , CosinePrior , SinePrior , PowerLawPrior
6
+ from jimgw .single_event .detector import H1 , L1
7
+ from jimgw .single_event .likelihood import TransientLikelihoodFD
8
+ from jimgw .single_event .waveform import RippleIMRPhenomD
9
+ from jimgw .transforms import BoundToUnbound
10
+ from jimgw .single_event .transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform , SkyFrameToDetectorFrameSkyPositionTransform , ComponentMassesToChirpMassMassRatioTransform
11
+ from jimgw .single_event .utils import Mc_q_to_m1_m2
12
+ from flowMC .strategy .optimization import optimization_Adam
13
+
14
+ jax .config .update ("jax_enable_x64" , True )
15
+
16
+ ###########################################
17
+ ########## First we grab data #############
18
+ ###########################################
19
+
20
+ # first, fetch a 4s segment centered on GW150914
21
+ gps = 1126259462.4
22
+ duration = 4
23
+ post_trigger_duration = 2
24
+ start_pad = duration - post_trigger_duration
25
+ end_pad = post_trigger_duration
26
+ fmin = 20.0
27
+ fmax = 1024.0
28
+
29
+ ifos = [H1 , L1 ]
30
+
31
+ for ifo in ifos :
32
+ ifo .load_data (gps , start_pad , end_pad , fmin , fmax , psd_pad = 16 , tukey_alpha = 0.2 )
33
+
34
+ M_c_min , M_c_max = 10.0 , 80.0
35
+ eta_min , eta_max = 0.2 , 0.25
36
+ # m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"])
37
+ # m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"])
38
+ Mc_prior = UniformPrior (M_c_min , M_c_max , parameter_names = ["M_c" ])
39
+ eta_prior = UniformPrior (eta_min , eta_max , parameter_names = ["eta" ])
40
+ s1z_prior = UniformPrior (- 1.0 , 1.0 , parameter_names = ["s1_z" ])
41
+ s2z_prior = UniformPrior (- 1.0 , 1.0 , parameter_names = ["s2_z" ])
42
+ dL_prior = PowerLawPrior (1.0 , 2000.0 , 2.0 , parameter_names = ["d_L" ])
43
+ t_c_prior = UniformPrior (- 0.05 , 0.05 , parameter_names = ["t_c" ])
44
+ phase_c_prior = UniformPrior (0.0 , 2 * jnp .pi , parameter_names = ["phase_c" ])
45
+ iota_prior = SinePrior (parameter_names = ["iota" ])
46
+ psi_prior = UniformPrior (0.0 , jnp .pi , parameter_names = ["psi" ])
47
+ ra_prior = UniformPrior (0.0 , 2 * jnp .pi , parameter_names = ["ra" ])
48
+ dec_prior = CosinePrior (parameter_names = ["dec" ])
49
+
50
+ prior = CombinePrior (
51
+ [
52
+ Mc_prior ,
53
+ eta_prior ,
54
+ s1z_prior ,
55
+ s2z_prior ,
56
+ dL_prior ,
57
+ t_c_prior ,
58
+ phase_c_prior ,
59
+ iota_prior ,
60
+ psi_prior ,
61
+ ra_prior ,
62
+ dec_prior ,
63
+ ]
64
+ )
65
+
66
+ sample_transforms = [
67
+ # ComponentMassesToChirpMassMassRatioTransform,
68
+ BoundToUnbound (name_mapping = (["M_c" ], ["M_c_unbounded" ]), original_lower_bound = M_c_min , original_upper_bound = M_c_max ),
69
+ BoundToUnbound (name_mapping = (["eta" ], ["eta_unbounded" ]), original_lower_bound = eta_min , original_upper_bound = eta_max ),
70
+ BoundToUnbound (name_mapping = (["s1_z" ], ["s1_z_unbounded" ]) , original_lower_bound = - 1.0 , original_upper_bound = 1.0 ),
71
+ BoundToUnbound (name_mapping = (["s2_z" ], ["s2_z_unbounded" ]) , original_lower_bound = - 1.0 , original_upper_bound = 1.0 ),
72
+ BoundToUnbound (name_mapping = (["d_L" ], ["d_L_unbounded" ]) , original_lower_bound = 1.0 , original_upper_bound = 2000.0 ),
73
+ BoundToUnbound (name_mapping = (["t_c" ], ["t_c_unbounded" ]) , original_lower_bound = - 0.05 , original_upper_bound = 0.05 ),
74
+ BoundToUnbound (name_mapping = (["phase_c" ], ["phase_c_unbounded" ]) , original_lower_bound = 0.0 , original_upper_bound = 2 * jnp .pi ),
75
+ BoundToUnbound (name_mapping = (["iota" ], ["iota_unbounded" ]), original_lower_bound = 0. , original_upper_bound = jnp .pi ),
76
+ BoundToUnbound (name_mapping = (["psi" ], ["psi_unbounded" ]), original_lower_bound = 0.0 , original_upper_bound = jnp .pi ),
77
+ SkyFrameToDetectorFrameSkyPositionTransform (gps_time = gps , ifos = ifos ),
78
+ BoundToUnbound (name_mapping = (["zenith" ], ["zenith_unbounded" ]), original_lower_bound = 0.0 , original_upper_bound = jnp .pi ),
79
+ BoundToUnbound (name_mapping = (["azimuth" ], ["azimuth_unbounded" ]), original_lower_bound = 0.0 , original_upper_bound = 2 * jnp .pi ),
80
+ ]
81
+
82
+ likelihood_transforms = [
83
+ # ComponentMassesToChirpMassSymmetricMassRatioTransform,
84
+ ]
85
+
86
+ likelihood = TransientLikelihoodFD (
87
+ ifos ,
88
+ waveform = RippleIMRPhenomD (),
89
+ trigger_time = gps ,
90
+ duration = 4 ,
91
+ post_trigger_duration = 2 ,
92
+ )
93
+
94
+
95
+ mass_matrix = jnp .eye (11 )
96
+ mass_matrix = mass_matrix .at [1 , 1 ].set (1e-3 )
97
+ mass_matrix = mass_matrix .at [5 , 5 ].set (1e-3 )
98
+ local_sampler_arg = {"step_size" : mass_matrix * 3e-3 }
99
+
100
+ Adam_optimizer = optimization_Adam (n_steps = 3000 , learning_rate = 0.01 , noise_level = 1 )
101
+
102
+ n_epochs = 30
103
+ n_loop_training = 20
104
+ learning_rate = 1e-4
105
+
106
+
107
+ jim = Jim (
108
+ likelihood ,
109
+ prior ,
110
+ sample_transforms = sample_transforms ,
111
+ likelihood_transforms = likelihood_transforms ,
112
+ n_loop_training = n_loop_training ,
113
+ n_loop_production = 20 ,
114
+ n_local_steps = 10 ,
115
+ n_global_steps = 1000 ,
116
+ n_chains = 500 ,
117
+ n_epochs = n_epochs ,
118
+ learning_rate = learning_rate ,
119
+ n_max_examples = 30000 ,
120
+ n_flow_samples = 100000 ,
121
+ momentum = 0.9 ,
122
+ batch_size = 30000 ,
123
+ use_global = True ,
124
+ train_thinning = 1 ,
125
+ output_thinning = 10 ,
126
+ local_sampler_arg = local_sampler_arg ,
127
+ strategies = [Adam_optimizer , "default" ],
128
+ verbose = True
129
+ )
130
+
131
+ jim .sample (jax .random .PRNGKey (42 ))
132
+ # jim.get_samples()
133
+ # jim.print_summary()
0 commit comments