Skip to content

Commit

Permalink
Merge pull request #147 from kazewong/jim-dev
Browse files Browse the repository at this point in the history
Sync
  • Loading branch information
CharmaineWONG2 authored Sep 10, 2024
2 parents 1b79a3a + d2c0416 commit d657eb5
Show file tree
Hide file tree
Showing 11 changed files with 515 additions and 174 deletions.
23 changes: 14 additions & 9 deletions example/Single_event_runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,20 @@
]
)


run = SingleEventRun(
seed=0,
detectors=["H1", "L1"],
priors={
"M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
"M_c": {"name": "Unconstrained_Uniform", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "MassRatio"},
"s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"s1_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Unconstrained_Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"cos_iota": {"name": "CosIota"},
"psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"psi": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"sin_dec": {"name": "SinDec"},
},
waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0},
Expand Down Expand Up @@ -90,3 +89,9 @@
)

run_manager = SingleEventPERunManager(run=run)
run_manager.jim.sample(jax.random.PRNGKey(42))

# plot the corner plot and diagnostic plot
run_manager.plot_corner()
run_manager.plot_diagnostic()

79 changes: 78 additions & 1 deletion src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import corner
import numpy as np
import yaml
from astropy.time import Time
from jaxlib.xla_extension import ArrayImpl
Expand Down Expand Up @@ -71,7 +73,8 @@ class SingleEventRun:
str, dict[str, Union[str, float, int, bool]]
] # Transform cannot be included in this way, add it to preset if used often.
jim_parameters: dict[str, Union[str, float, int, bool, dict]]
injection_parameters: dict[str, float]
path: str = "./experiment"
injection_parameters: dict[str, float] = field(default_factory=lambda: {})
injection: bool = False
likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field(
default_factory=lambda: {"name": "TransientLikelihoodFD"}
Expand Down Expand Up @@ -123,6 +126,9 @@ def __init__(self, **kwargs):
print("Neither run instance nor path provided.")
raise ValueError

if self.run.injection and not self.run.injection_parameters:
raise ValueError("Injection mode requires injection parameters.")

local_prior = self.initialize_prior()
local_likelihood = self.initialize_likelihood(local_prior)
self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters)
Expand Down Expand Up @@ -150,6 +156,7 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood:
waveform = self.initialize_waveform()
name = self.run.likelihood_parameters["name"]
assert isinstance(name, str), "Likelihood name must be a string."
assert name in likelihood_presets, f"Likelihood {name} not recognized."
if self.run.injection:
freqs = jnp.linspace(
self.run.data_parameters["f_min"],
Expand Down Expand Up @@ -351,3 +358,73 @@ def plot_data(self, path: str):
plt.ylabel("Amplitude")
plt.legend()
plt.savefig(path)

def sample(self):
self.jim.sample(jax.random.PRNGKey(self.run.seed))

def get_samples(self):
return self.jim.get_samples()

def plot_corner(self, path: str = "corner.jpeg", **kwargs):
"""
plot corner plot of the samples.
"""
plot_datapoint = kwargs.get("plot_datapoints", False)
title_quantiles = kwargs.get("title_quantiles", [0.16, 0.5, 0.84])
show_titles = kwargs.get("show_titles", True)
title_fmt = kwargs.get("title_fmt", ".2E")
use_math_text = kwargs.get("use_math_text", True)

samples = self.jim.get_samples()
param_names = list(samples.keys())
samples = np.array(list(samples.values())).reshape(int(len(param_names)), -1).T
corner.corner(
samples,
labels=param_names,
plot_datapoints=plot_datapoint,
title_quantiles=title_quantiles,
show_titles=show_titles,
title_fmt=title_fmt,
use_math_text=use_math_text,
**kwargs,
)
plt.savefig(path)
plt.close()

def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs):
"""
plot diagnostic plot of the samples.
"""
summary = self.jim.Sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = summary.values()
log_prob = np.array(log_prob)

plt.figure(figsize=(10, 10))
axs = [plt.subplot(2, 2, i + 1) for i in range(4)]
plt.sca(axs[0])
plt.title("log probability")
plt.plot(log_prob.mean(0))
plt.xlabel("iteration")
plt.xlim(0, None)

plt.sca(axs[1])
plt.title("NF loss")
plt.plot(loss_vals.reshape(-1))
plt.xlabel("iteration")
plt.xlim(0, None)

plt.sca(axs[2])
plt.title("Local Acceptance")
plt.plot(local_accs.mean(0))
plt.xlabel("iteration")
plt.xlim(0, None)

plt.sca(axs[3])
plt.title("Global Acceptance")
plt.plot(global_accs.mean(0))
plt.xlabel("iteration")
plt.xlim(0, None)
plt.tight_layout()

plt.savefig(path)
plt.close()
Loading

0 comments on commit d657eb5

Please sign in to comment.