Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

76 sync with flowmc 030 api #78

Merged
merged 14 commits into from
May 7, 2024
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
45 changes: 29 additions & 16 deletions example/GW150914.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import time

import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import Composite, Unconstrained_Uniform
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import (
HeterodynedTransientLikelihoodFD,
TransientLikelihoodFD,
)
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.prior import Unconstrained_Uniform, Composite
import jax.numpy as jnp
import jax
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -102,24 +102,37 @@
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)

import optax
n_epochs = 20
n_loop_training = 100
total_epochs = n_epochs * n_loop_training
start = total_epochs//10
learning_rate = optax.polynomial_schedule(
1e-3, 1e-4, 400, total_epochs - start, transition_begin=start
)


jim = Jim(
likelihood,
prior,
n_loop_training=100,
n_loop_production=10,
n_local_steps=150,
n_global_steps=150,
n_loop_training=n_loop_training,
n_loop_production=20,
n_local_steps=10,
n_global_steps=1000,
n_chains=500,
n_epochs=50,
learning_rate=0.001,
max_samples=45000,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30000,
n_flow_samples=100000,
momentum=0.9,
batch_size=50000,
batch_size=30000,
use_global=True,
keep_quantile=0.0,
train_thinning=1,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
strategies=[Adam_optimizer,"default"],
)

jim.sample(jax.random.PRNGKey(42))
85 changes: 53 additions & 32 deletions example/GW150914_PV2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import time

import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import Composite, Sphere, Unconstrained_Uniform
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomPv2
from jimgw.prior import Uniform, Composite, Sphere
import jax.numpy as jnp
import jax
from flowMC.strategy.optimization import optimization_Adam


jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -33,19 +37,19 @@
########## Set up priors ##################
###########################################

Mc_prior = Uniform(10.0, 80.0, naming=["M_c"])
q_prior = Uniform(
Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"])
q_prior = Unconstrained_Uniform(
0.125,
1.0,
naming=["q"],
transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)},
)
s1_prior = Sphere(naming="s1")
s2_prior = Sphere(naming="s2")
dL_prior = Uniform(0.0, 2000.0, naming=["d_L"])
t_c_prior = Uniform(-0.05, 0.05, naming=["t_c"])
phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"])
cos_iota_prior = Uniform(
dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"])
t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"])
phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"])
cos_iota_prior = Unconstrained_Uniform(
-1.0,
1.0,
naming=["cos_iota"],
Expand All @@ -58,9 +62,9 @@
)
},
)
psi_prior = Uniform(0.0, jnp.pi, naming=["psi"])
ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"])
sin_dec_prior = Uniform(
psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"])
ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"])
sin_dec_prior = Unconstrained_Uniform(
-1.0,
1.0,
naming=["sin_dec"],
Expand Down Expand Up @@ -90,55 +94,72 @@
],
)

epsilon = 1e-3
bounds = jnp.array(
[
[10.0, 80.0],
[0.125, 1.0],
[0, jnp.pi],
[0, 2*jnp.pi],
[0, 2 * jnp.pi],
[0.0, 1.0],
[0, jnp.pi],
[0, 2*jnp.pi],
[0, 2 * jnp.pi],
[0.0, 1.0],
[0.0, 2000.0],
[0.0, 2000],
[-0.05, 0.05],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
[0.0, jnp.pi],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
]
) + jnp.array([[epsilon, -epsilon]])

likelihood = TransientLikelihoodFD(
[H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2
)
likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)
# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)


mass_matrix = jnp.eye(prior.n_dim)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[9, 9].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}
local_sampler_arg = {"step_size": mass_matrix * 1e-3}

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1, bounds=bounds)

import optax
n_epochs = 20
n_loop_training = 100
total_epochs = n_epochs * n_loop_training
start = total_epochs//10
learning_rate = optax.polynomial_schedule(
1e-3, 1e-4, 400, total_epochs - start, transition_begin=start
)

jim = Jim(
likelihood,
prior,
n_loop_training=200,
n_loop_production=10,
n_local_steps=300,
n_global_steps=300,
n_loop_training=n_loop_training,
n_loop_production=20,
n_local_steps=10,
n_global_steps=1000,
n_chains=500,
n_epochs=300,
learning_rate=0.001,
max_samples = 10000,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30000,
n_flow_sample=100000,
momentum=0.9,
batch_size=10000,
batch_size=30000,
use_global=True,
keep_quantile=0.,
keep_quantile=0.0,
train_thinning=1,
output_thinning=30,
num_layers=6,
hidden_size=[64, 64],
num_bins=8,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
# strategies=[Adam_optimizer,"default"],
)

jim.sample(jax.random.PRNGKey(42))
import numpy as np
# chains = np.load('./GW150914_init.npz')['chain']

jim.sample(jax.random.PRNGKey(42))#,initial_guess=chains)
152 changes: 0 additions & 152 deletions example/GW150914_PV2_newglobal.py

This file was deleted.

Loading
Loading