Skip to content

Commit 24235c2

Browse files
authored
Merge pull request #152 from kazewong/150-slow-initialize-point-generation
150 slow initialize point generation
2 parents d2c0416 + 93b6483 commit 24235c2

File tree

3 files changed

+146
-150
lines changed

3 files changed

+146
-150
lines changed

example/GW150914.py

-138
This file was deleted.

example/GW150914_IMRPhenomD.py

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
from jimgw.jim import Jim
5+
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior
6+
from jimgw.single_event.detector import H1, L1
7+
from jimgw.single_event.likelihood import TransientLikelihoodFD
8+
from jimgw.single_event.waveform import RippleIMRPhenomD
9+
from jimgw.transforms import BoundToUnbound
10+
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform
11+
from jimgw.single_event.utils import Mc_q_to_m1_m2
12+
from flowMC.strategy.optimization import optimization_Adam
13+
14+
jax.config.update("jax_enable_x64", True)
15+
16+
###########################################
17+
########## First we grab data #############
18+
###########################################
19+
20+
# first, fetch a 4s segment centered on GW150914
21+
gps = 1126259462.4
22+
duration = 4
23+
post_trigger_duration = 2
24+
start_pad = duration - post_trigger_duration
25+
end_pad = post_trigger_duration
26+
fmin = 20.0
27+
fmax = 1024.0
28+
29+
ifos = [H1, L1]
30+
31+
for ifo in ifos:
32+
ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
33+
34+
M_c_min, M_c_max = 10.0, 80.0
35+
eta_min, eta_max = 0.2, 0.25
36+
# m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"])
37+
# m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"])
38+
Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"])
39+
eta_prior = UniformPrior(eta_min, eta_max, parameter_names=["eta"])
40+
s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"])
41+
s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"])
42+
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
43+
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
44+
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
45+
iota_prior = SinePrior(parameter_names=["iota"])
46+
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"])
47+
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"])
48+
dec_prior = CosinePrior(parameter_names=["dec"])
49+
50+
prior = CombinePrior(
51+
[
52+
Mc_prior,
53+
eta_prior,
54+
s1z_prior,
55+
s2z_prior,
56+
dL_prior,
57+
t_c_prior,
58+
phase_c_prior,
59+
iota_prior,
60+
psi_prior,
61+
ra_prior,
62+
dec_prior,
63+
]
64+
)
65+
66+
sample_transforms = [
67+
# ComponentMassesToChirpMassMassRatioTransform,
68+
BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
69+
BoundToUnbound(name_mapping = (["eta"], ["eta_unbounded"]), original_lower_bound=eta_min, original_upper_bound=eta_max),
70+
BoundToUnbound(name_mapping = (["s1_z"], ["s1_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0),
71+
BoundToUnbound(name_mapping = (["s2_z"], ["s2_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0),
72+
BoundToUnbound(name_mapping = (["d_L"], ["d_L_unbounded"]) , original_lower_bound=1.0, original_upper_bound=2000.0),
73+
BoundToUnbound(name_mapping = (["t_c"], ["t_c_unbounded"]) , original_lower_bound=-0.05, original_upper_bound=0.05),
74+
BoundToUnbound(name_mapping = (["phase_c"], ["phase_c_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
75+
BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]), original_lower_bound=0., original_upper_bound=jnp.pi),
76+
BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
77+
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
78+
BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
79+
BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
80+
]
81+
82+
likelihood_transforms = [
83+
# ComponentMassesToChirpMassSymmetricMassRatioTransform,
84+
]
85+
86+
likelihood = TransientLikelihoodFD(
87+
ifos,
88+
waveform=RippleIMRPhenomD(),
89+
trigger_time=gps,
90+
duration=4,
91+
post_trigger_duration=2,
92+
)
93+
94+
95+
mass_matrix = jnp.eye(11)
96+
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
97+
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
98+
local_sampler_arg = {"step_size": mass_matrix * 3e-3}
99+
100+
Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)
101+
102+
n_epochs = 30
103+
n_loop_training = 20
104+
learning_rate = 1e-4
105+
106+
107+
jim = Jim(
108+
likelihood,
109+
prior,
110+
sample_transforms=sample_transforms,
111+
likelihood_transforms=likelihood_transforms,
112+
n_loop_training=n_loop_training,
113+
n_loop_production=20,
114+
n_local_steps=10,
115+
n_global_steps=1000,
116+
n_chains=500,
117+
n_epochs=n_epochs,
118+
learning_rate=learning_rate,
119+
n_max_examples=30000,
120+
n_flow_samples=100000,
121+
momentum=0.9,
122+
batch_size=30000,
123+
use_global=True,
124+
train_thinning=1,
125+
output_thinning=10,
126+
local_sampler_arg=local_sampler_arg,
127+
strategies=[Adam_optimizer, "default"],
128+
verbose=True
129+
)
130+
131+
jim.sample(jax.random.PRNGKey(42))
132+
# jim.get_samples()
133+
# jim.print_summary()

src/jimgw/jim.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,19 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict):
104104

105105
def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
106106
if initial_position.size == 0:
107-
initial_guess = []
108-
for _ in range(self.sampler.n_chains):
109-
flag = True
110-
while flag:
111-
key = jax.random.split(key)[1]
112-
guess = self.prior.sample(key, 1)
113-
for transform in self.sample_transforms:
114-
guess = transform.forward(guess)
115-
guess = jnp.array([i for i in guess.values()]).T[0]
116-
flag = not jnp.all(jnp.isfinite(guess))
117-
initial_guess.append(guess)
118-
initial_position = jnp.array(initial_guess)
107+
initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan
108+
109+
while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all():
110+
non_finite_index = jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1)
111+
112+
key, subkey = jax.random.split(key)
113+
guess = self.prior.sample(subkey, self.sampler.n_chains)
114+
for transform in self.sample_transforms:
115+
guess = jax.vmap(transform.forward)(guess)
116+
guess = jnp.array(jax.tree.leaves({key: guess[key] for key in self.parameter_names})).T
117+
finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0]
118+
common_length = min(len(finite_guess), len(non_finite_index))
119+
initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length])
119120
self.sampler.sample(initial_position, None) # type: ignore
120121

121122
def maximize_likelihood(

0 commit comments

Comments
 (0)