Merge pull request #78 from kazewong/76-sync-with-flowmc-030-api
76: sync with flowMC 0.3.0 API
kazewong authored May 7, 2024
2 parents a878868 + 07cb1ce commit 4aecf6a
Showing 25 changed files with 239 additions and 487 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
@@ -12,7 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
45 changes: 29 additions & 16 deletions example/GW150914.py
@@ -1,14 +1,14 @@
import time

import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import Composite, Unconstrained_Uniform
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import (
HeterodynedTransientLikelihoodFD,
TransientLikelihoodFD,
)
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.prior import Unconstrained_Uniform, Composite
import jax.numpy as jnp
import jax
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)

@@ -102,24 +102,37 @@
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)

import optax
n_epochs = 20
n_loop_training = 100
total_epochs = n_epochs * n_loop_training
start = total_epochs//10
learning_rate = optax.polynomial_schedule(
1e-3, 1e-4, 400, total_epochs - start, transition_begin=start
)


jim = Jim(
likelihood,
prior,
n_loop_training=100,
n_loop_production=10,
n_local_steps=150,
n_global_steps=150,
n_loop_training=n_loop_training,
n_loop_production=20,
n_local_steps=10,
n_global_steps=1000,
n_chains=500,
n_epochs=50,
learning_rate=0.001,
max_samples=45000,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30000,
n_flow_samples=100000,
momentum=0.9,
batch_size=50000,
batch_size=30000,
use_global=True,
keep_quantile=0.0,
train_thinning=1,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
strategies=[Adam_optimizer,"default"],
)

jim.sample(jax.random.PRNGKey(42))
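
The new configuration replaces the fixed learning_rate=0.001 with an optax polynomial schedule built from n_epochs and n_loop_training. A minimal standalone sketch (not part of the commit; the numbers are copied from the diff above) of how that schedule behaves:

import optax

n_epochs = 20
n_loop_training = 100
total_epochs = n_epochs * n_loop_training   # 2000 scheduler steps in total
start = total_epochs // 10                  # hold the initial rate for the first 10% of steps

# Same call as in the diff: decay from 1e-3 towards 1e-4 once `start` steps have passed.
schedule = optax.polynomial_schedule(
    init_value=1e-3,
    end_value=1e-4,
    power=400,
    transition_steps=total_epochs - start,
    transition_begin=start,
)

for step in (0, start, start + 100, total_epochs):
    print(step, float(schedule(step)))      # stays at 1e-3 up to `start`, then falls towards 1e-4
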
85 changes: 53 additions & 32 deletions example/GW150914_PV2.py
@@ -1,11 +1,15 @@
import time

import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import Composite, Sphere, Unconstrained_Uniform
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomPv2
from jimgw.prior import Uniform, Composite, Sphere
import jax.numpy as jnp
import jax
from flowMC.strategy.optimization import optimization_Adam


jax.config.update("jax_enable_x64", True)

@@ -33,19 +37,19 @@
########## Set up priors ##################
###########################################

Mc_prior = Uniform(10.0, 80.0, naming=["M_c"])
q_prior = Uniform(
Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"])
q_prior = Unconstrained_Uniform(
0.125,
1.0,
naming=["q"],
transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)},
)
s1_prior = Sphere(naming="s1")
s2_prior = Sphere(naming="s2")
dL_prior = Uniform(0.0, 2000.0, naming=["d_L"])
t_c_prior = Uniform(-0.05, 0.05, naming=["t_c"])
phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"])
cos_iota_prior = Uniform(
dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"])
t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"])
phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"])
cos_iota_prior = Unconstrained_Uniform(
-1.0,
1.0,
naming=["cos_iota"],
@@ -58,9 +62,9 @@
)
},
)
psi_prior = Uniform(0.0, jnp.pi, naming=["psi"])
ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"])
sin_dec_prior = Uniform(
psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"])
ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"])
sin_dec_prior = Unconstrained_Uniform(
-1.0,
1.0,
naming=["sin_dec"],
@@ -90,55 +94,72 @@
],
)

epsilon = 1e-3
bounds = jnp.array(
[
[10.0, 80.0],
[0.125, 1.0],
[0, jnp.pi],
[0, 2*jnp.pi],
[0, 2 * jnp.pi],
[0.0, 1.0],
[0, jnp.pi],
[0, 2*jnp.pi],
[0, 2 * jnp.pi],
[0.0, 1.0],
[0.0, 2000.0],
[0.0, 2000],
[-0.05, 0.05],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
[0.0, jnp.pi],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
]
) + jnp.array([[epsilon, -epsilon]])
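# Note: the [epsilon, -epsilon] offset above narrows each interval by epsilon at both ends,
# so the bounds stay strictly inside the prior ranges.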

likelihood = TransientLikelihoodFD(
[H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2
)
likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)
# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)


mass_matrix = jnp.eye(prior.n_dim)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[9, 9].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}
local_sampler_arg = {"step_size": mass_matrix * 1e-3}

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1, bounds=bounds)

import optax
n_epochs = 20
n_loop_training = 100
total_epochs = n_epochs * n_loop_training
start = total_epochs//10
learning_rate = optax.polynomial_schedule(
1e-3, 1e-4, 400, total_epochs - start, transition_begin=start
)

jim = Jim(
likelihood,
prior,
n_loop_training=200,
n_loop_production=10,
n_local_steps=300,
n_global_steps=300,
n_loop_training=n_loop_training,
n_loop_production=20,
n_local_steps=10,
n_global_steps=1000,
n_chains=500,
n_epochs=300,
learning_rate=0.001,
max_samples = 10000,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30000,
n_flow_sample=100000,
momentum=0.9,
batch_size=10000,
batch_size=30000,
use_global=True,
keep_quantile=0.,
keep_quantile=0.0,
train_thinning=1,
output_thinning=30,
num_layers=6,
hidden_size=[64, 64],
num_bins=8,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
# strategies=[Adam_optimizer,"default"],
)

jim.sample(jax.random.PRNGKey(42))
import numpy as np
# chains = np.load('./GW150914_init.npz')['chain']

jim.sample(jax.random.PRNGKey(42))#,initial_guess=chains)
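
Both updated examples import optimization_Adam from flowMC 0.3.0: GW150914.py enables it through the new strategies argument, while GW150914_PV2.py keeps it commented out. A minimal sketch (assuming likelihood, prior, and local_sampler_arg are constructed as in the diffs above, and that Jim forwards these keyword arguments to flowMC's sampler) of switching the Adam pre-optimization stage on:

import jax
from flowMC.strategy.optimization import optimization_Adam
from jimgw.jim import Jim

# Adam pre-optimization stage with the same settings as in the GW150914.py diff.
Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)

jim = Jim(
    likelihood,                              # assumed: built as in the examples above
    prior,
    n_chains=500,
    local_sampler_arg=local_sampler_arg,
    strategies=[Adam_optimizer, "default"],  # Adam runs first, then the default training/production loop
)
jim.sample(jax.random.PRNGKey(42))
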
152 changes: 0 additions & 152 deletions example/GW150914_PV2_newglobal.py

This file was deleted.
