Skip to content

Commit

Permalink
Merge pull request #59 from kazewong/12-batch-job-submission
Browse files Browse the repository at this point in the history
12 batch job submission
  • Loading branch information
kazewong authored Jan 3, 2024
2 parents 0a45b88 + e889ec2 commit de17d53
Show file tree
Hide file tree
Showing 22 changed files with 728 additions and 183 deletions.
4 changes: 4 additions & 0 deletions docs/tutorials/anatomy_of_jim.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,8 @@ At its core, `flowMC` is still a MCMC algorithm, so the hyperparameter tuning is
1. If you can, use more chains, especially on a GPU. Bring the number of chains up until you start to get significant performance hit or run out of memory.
2. Run it longer, in particular the training phase. In fact, most of the computation cost goes into the training part, once you get a reasonably tuned normalizing flow model, the production phase is usually quite cheap. To be concrete, blow `n_loop_training` up until you cannot stand how slow it is.

## Run Manager

While Jim is the main object that will handle most of the work, there are a lot of bookkeeping that needs to be done around a run.

## Analysis
9 changes: 6 additions & 3 deletions example/GW150914.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import time
from jimgw.jim import Jim
from jimgw.detector import H1, L1
from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.waveform import RippleIMRPhenomD
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import (
HeterodynedTransientLikelihoodFD,
TransientLikelihoodFD,
)
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.prior import Unconstrained_Uniform, Composite
import jax.numpy as jnp
import jax
Expand Down
6 changes: 3 additions & 3 deletions example/GW150914_PV2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import time
from jimgw.jim import Jim
from jimgw.detector import H1, L1
from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.waveform import RippleIMRPhenomPv2
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomPv2
from jimgw.prior import Uniform, Composite, Sphere
import jax.numpy as jnp
import jax
Expand Down
6 changes: 3 additions & 3 deletions example/GW150914_PV2_newglobal.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import time
from jimgw.jim import Jim
from jimgw.detector import H1, L1
from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2
from jimgw.prior import Uniform, Unconstrained_Uniform, Composite, Sphere
import jax.numpy as jnp
import jax
Expand Down
6 changes: 3 additions & 3 deletions example/GW150914_heterodyne.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import time
from jimgw.jim import Jim
from jimgw.detector import H1, L1
from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.waveform import RippleIMRPhenomD
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.prior import Uniform, Composite
import jax.numpy as jnp
import jax
Expand Down
6 changes: 3 additions & 3 deletions example/GW170817.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import time
from jimgw.jim import Jim
from jimgw.detector import H1, L1, V1
from jimgw.likelihood import HeterodynedTransientLikelihoodFD
from jimgw.waveform import RippleIMRPhenomD
from jimgw.single_event.detector import H1, L1, V1
from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.prior import Uniform
from gwosc.datasets import event_gps
import jax.numpy as jnp
Expand Down
72 changes: 44 additions & 28 deletions example/GW170817_heterodyne.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from jimgw.jim import Jim
from jimgw.detector import H1, L1, V1
from jimgw.likelihood import HeterodynedTransientLikelihoodFD
from jimgw.waveform import RippleIMRPhenomD
from jimgw.prior import Uniform, Powerlaw, Composite
from jimgw.single_event.detector import H1, L1, V1
from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.prior import Uniform, PowerLaw, Composite
import jax.numpy as jnp
import jax
import time

jax.config.update("jax_enable_x64", True)


Expand All @@ -22,14 +23,20 @@
duration = T
post_trigger_duration = 2
epoch = duration - post_trigger_duration
f_ref = fmin
f_ref = fmin

### Getting ifos and overwriting with above data

tukey_alpha = 2 / (duration / 2)
H1.load_data(gps, duration, 2, fmin, fmax, psd_pad=duration+16, tukey_alpha=tukey_alpha)
L1.load_data(gps, duration, 2, fmin, fmax, psd_pad=duration+16, tukey_alpha=tukey_alpha)
V1.load_data(gps, duration, 2, fmin, fmax, psd_pad=duration+16, tukey_alpha=tukey_alpha)
H1.load_data(
gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha
)
L1.load_data(
gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha
)
V1.load_data(
gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha
)

### Define priors

Expand All @@ -41,13 +48,13 @@
naming=["q"],
transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)},
)
s1z_prior = Uniform(-0.05, 0.05, naming=["s1_z"])
s2z_prior = Uniform(-0.05, 0.05, naming=["s2_z"])
s1z_prior = Uniform(-0.05, 0.05, naming=["s1_z"])
s2z_prior = Uniform(-0.05, 0.05, naming=["s2_z"])

# External parameters
dL_prior = Powerlaw(1.0, 75.0, 2.0, naming=["d_L"])
t_c_prior = Uniform(-0.1, 0.1, naming=["t_c"])
phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"])
dL_prior = PowerLaw(1.0, 75.0, 2.0, naming=["d_L"])
t_c_prior = Uniform(-0.1, 0.1, naming=["t_c"])
phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"])
cos_iota_prior = Uniform(
-1.0,
1.0,
Expand All @@ -61,8 +68,8 @@
)
},
)
psi_prior = Uniform(0.0, jnp.pi, naming=["psi"])
ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"])
psi_prior = Uniform(0.0, jnp.pi, naming=["psi"])
ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"])
sin_dec_prior = Uniform(
-1.0,
1.0,
Expand All @@ -77,7 +84,8 @@
},
)

prior = Composite([
prior = Composite(
[
Mc_prior,
q_prior,
s1z_prior,
Expand All @@ -96,19 +104,27 @@
bounds = jnp.array([[p.xmin, p.xmax] for p in prior.priors])

### Create likelihood object
likelihood = HeterodynedTransientLikelihoodFD([H1, L1, V1], prior=prior, bounds=bounds, waveform=RippleIMRPhenomD(), trigger_time=gps, duration=T, n_bins=500)
likelihood = HeterodynedTransientLikelihoodFD(
[H1, L1, V1],
prior=prior,
bounds=bounds,
waveform=RippleIMRPhenomD(),
trigger_time=gps,
duration=T,
n_bins=500,
)

### Create sampler and jim objects
eps = 3e-2
n_dim = 11
mass_matrix = jnp.eye(n_dim)
mass_matrix = mass_matrix.at[0,0].set(1e-5)
mass_matrix = mass_matrix.at[1,1].set(1e-4)
mass_matrix = mass_matrix.at[2,2].set(1e-3)
mass_matrix = mass_matrix.at[3,3].set(1e-3)
mass_matrix = mass_matrix.at[5,5].set(1e-5)
mass_matrix = mass_matrix.at[9,9].set(1e-2)
mass_matrix = mass_matrix.at[10,10].set(1e-2)
mass_matrix = mass_matrix.at[0, 0].set(1e-5)
mass_matrix = mass_matrix.at[1, 1].set(1e-4)
mass_matrix = mass_matrix.at[2, 2].set(1e-3)
mass_matrix = mass_matrix.at[3, 3].set(1e-3)
mass_matrix = mass_matrix.at[5, 5].set(1e-5)
mass_matrix = mass_matrix.at[9, 9].set(1e-2)
mass_matrix = mass_matrix.at[10, 10].set(1e-2)
local_sampler_arg = {"step_size": mass_matrix * eps}

outdir_name = "./outdir/"
Expand All @@ -129,11 +145,11 @@
use_global=True,
keep_quantile=0.0,
train_thinning=10,
output_thinning=30,
n_loops_maximize_likelihood = 2000,
output_thinning=30,
n_loops_maximize_likelihood=2000,
local_sampler_arg=local_sampler_arg,
outdir_name=outdir_name
outdir_name=outdir_name,
)

jim.sample(jax.random.PRNGKey(42))
jim.print_summary()
jim.print_summary()
6 changes: 3 additions & 3 deletions example/InjectionRecovery.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jimgw.jim import Jim
from jimgw.detector import H1, L1, V1
from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.waveform import RippleIMRPhenomPv2
from jimgw.single_event.detector import H1, L1, V1
from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomPv2
from jimgw.prior import Uniform
from ripple import ms_to_Mc_eta
import jax.numpy as jnp
Expand Down
91 changes: 91 additions & 0 deletions example/Single_event_runManager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@

from jimgw.single_event.runManager import SingleEventPERunManager, SingleEventRun
import jax.numpy as jnp
import jax

jax.config.update("jax_enable_x64", True)

mass_matrix = jnp.eye(11)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
mass_matrix = mass_matrix * 3e-3
local_sampler_arg = {"step_size": mass_matrix}
bounds = jnp.array(
[
[10.0, 40.0],
[0.125, 1.0],
[-1.0, 1.0],
[-1.0, 1.0],
[0.0, 2000.0],
[-0.05, 0.05],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
[0.0, jnp.pi],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
]
)


run = SingleEventRun(
seed=0,
path="test_data/GW150914/",
detectors=["H1", "L1"],
priors={
"M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "MassRatio"},
"s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"cos_iota": {"name": "CosIota"},
"psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"sin_dec": {"name": "SinDec"},
},
waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0},
jim_parameters={
"n_loop_training": 10,
"n_loop_production": 10,
"n_local_steps": 150,
"n_global_steps": 150,
"n_chains": 500,
"n_epochs": 50,
"learning_rate": 0.001,
"max_samples": 45000,
"momentum": 0.9,
"batch_size": 50000,
"use_global": True,
"keep_quantile": 0.0,
"train_thinning": 1,
"output_thinning": 10,
"local_sampler_arg": local_sampler_arg,
},
likelihood_parameters={"name": "HeterodynedTransientLikelihoodFD", "bounds": bounds},
injection=True,
injection_parameters={
"M_c": 28.6,
"eta": 0.24,
"s1_z": 0.05,
"s2_z": 0.05,
"d_L": 440.0,
"t_c": 0.0,
"phase_c": 0.0,
"iota": 0.5,
"psi": 0.7,
"ra": 1.2,
"dec": 0.3,
},
data_parameters={
"trigger_time": 1126259462.4,
"duration": 4,
"post_trigger_duration": 2,
"f_min": 20.0,
"f_max": 1024.0,
"tukey_alpha": 0.2,
"f_sampling": 4096.0,
},
)

run_manager = SingleEventPERunManager(run=run)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ packages = find:
install_requires =
jax>=0.4.12
jaxlib>=0.4.12
flowMC>=0.2.1
flowMC>=0.2.4
ripplegw
gwpy
corner
Expand Down
Loading

0 comments on commit de17d53

Please sign in to comment.