From 268b54f72a0fb643962b2bb1dacc4bb092f4eac0 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 28 Nov 2023 16:45:55 -0500 Subject: [PATCH 1/2] max_samples affect training speed a lot. Lower it a bit for user with less powerful GPU --- example/GW150914_PV2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index 140af05b..ac164357 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -101,16 +101,16 @@ jim = Jim( likelihood, prior, - n_loop_training=100, + n_loop_training=200, n_loop_production=10, n_local_steps=300, n_global_steps=300, n_chains=500, n_epochs=300, learning_rate=0.001, - max_samples = 60000, + max_samples = 10000, momentum=0.9, - batch_size=60000, + batch_size=10000, use_global=True, keep_quantile=0., train_thinning=1, From 84abc7aec9b1d88234379415397956cc23e3aec3 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 28 Nov 2023 16:46:22 -0500 Subject: [PATCH 2/2] Update new_global config --- example/GW150914_PV2_newglobal.py | 111 ++++++++++++++++++++++++------ 1 file changed, 91 insertions(+), 20 deletions(-) diff --git a/example/GW150914_PV2_newglobal.py b/example/GW150914_PV2_newglobal.py index d995df98..fce24f74 100644 --- a/example/GW150914_PV2_newglobal.py +++ b/example/GW150914_PV2_newglobal.py @@ -11,7 +11,7 @@ jax.config.update("jax_enable_x64", True) ########################################### -########## First we grab data ############# +########## This script is experimental #### ########################################### total_time_start = time.time() @@ -30,50 +30,121 @@ waveform = RippleIMRPhenomPv2(f_ref=20) -Mc_prior = Unconstrained_Uniform(10., 80., naming=["M_c"]) -q_prior = Unconstrained_Uniform(0.125, 1., naming=["q"], transforms={"q": ("eta", lambda params: params['q']/(1+params['q'])**2)}) +Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) +q_prior = Unconstrained_Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) s1_prior = Sphere("s1") s2_prior = Sphere("s2") -dL_prior = Unconstrained_Uniform(0., 2000., naming=["d_L"]) +dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) -phase_c_prior = Unconstrained_Uniform(0., 2*jnp.pi, naming=["phase_c"]) -cos_iota_prior = Unconstrained_Uniform(-1., 1., naming=["cos_iota"], transforms={"cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi))}) -psi_prior = Unconstrained_Uniform(0., jnp.pi, naming=["psi"]) -ra_prior = Unconstrained_Uniform(0., 2*jnp.pi, naming=["ra"]) -sin_dec_prior = Unconstrained_Uniform(-1., 1., naming=["sin_dec"], transforms={"sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))}) +phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) -prior = Composite([Mc_prior, q_prior, s1_prior, s2_prior, dL_prior, t_c_prior, phase_c_prior, cos_iota_prior, psi_prior, ra_prior, sin_dec_prior]) +prior = Composite( + [ + Mc_prior, + q_prior, + s1_prior, + s2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] +) -likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) +optimization_bounds = jnp.array( + [ + [-10.0, 10.0], + [-10.0, 10.0], + [0.0, 2.0 * jnp.pi], + [-1.0, 1.0], + [0.01, 1.0], + [0.0, 2.0 * jnp.pi], + [-1.0, 1.0], + [0.01, 1.0], + [-10.0, 10.0], + [-30.0, 30.0], + [-10.0, 10.0], + [-10.0, 10.0], + [-10.0, 10.0], + [-10.0, 10.0], + [-10.0, 10.0], + ] +) + +likelihood = TransientLikelihoodFD( + [H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2 +) mass_matrix = jnp.eye(prior.n_dim) -# mass_matrix = mass_matrix.at[1, 1].set(1e-3) -# mass_matrix = mass_matrix.at[9, 9].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[9, 9].set(1e-3) +mass_matrix = mass_matrix * 3e-3 +local_sampler_arg = {"step_size": mass_matrix} jim = Jim( likelihood, prior, - n_loop_training=50, + n_loop_training=20, n_loop_production=10, n_local_steps=300, n_global_steps=300, n_chains=500, n_epochs=300, learning_rate=0.001, - max_samples = 60000, + max_samples=60000, momentum=0.9, batch_size=30000, use_global=True, - keep_quantile=0., + keep_quantile=0.0, train_thinning=1, output_thinning=30, local_sampler_arg=local_sampler_arg, - num_layers = 6, - hidden_size = [32,32], - num_bins = 8 + num_layers=6, + hidden_size=[32, 32], + num_bins=8, + flowHMC_params={ + "step_size": 1e-2, + "n_leapfrog": 3, + "condition_matrix": jnp.linalg.inv(mass_matrix), + }, ) # jim.maximize_likelihood([prior.xmin, prior.xmax])