Skip to content

Commit

Permalink
Merge branch 'jim-dev' of github.com:kazewong/JaxGW into 98-moving-na…
Browse files Browse the repository at this point in the history
…ming-tracking-into-jim-class-from-prior-class
  • Loading branch information
kazewong committed Sep 2, 2024
2 parents 0e96439 + a46af6b commit b003785
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 10 deletions.
23 changes: 14 additions & 9 deletions example/Single_event_runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,20 @@
]
)


run = SingleEventRun(
seed=0,
detectors=["H1", "L1"],
priors={
"M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
"M_c": {"name": "Unconstrained_Uniform", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "MassRatio"},
"s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"s1_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Unconstrained_Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"cos_iota": {"name": "CosIota"},
"psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"psi": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"sin_dec": {"name": "SinDec"},
},
waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0},
Expand Down Expand Up @@ -90,3 +89,9 @@
)

run_manager = SingleEventPERunManager(run=run)
run_manager.jim.sample(jax.random.PRNGKey(42))

# plot the corner plot and diagnostic plot
run_manager.plot_corner()
run_manager.plot_diagnostic()

79 changes: 78 additions & 1 deletion src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import corner
import numpy as np
import yaml
from astropy.time import Time
from jaxlib.xla_extension import ArrayImpl
Expand Down Expand Up @@ -71,7 +73,8 @@ class SingleEventRun:
str, dict[str, Union[str, float, int, bool]]
] # Transform cannot be included in this way, add it to preset if used often.
jim_parameters: dict[str, Union[str, float, int, bool, dict]]
injection_parameters: dict[str, float]
path: str = "./experiment"
injection_parameters: dict[str, float] = field(default_factory=lambda: {})
injection: bool = False
likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field(
default_factory=lambda: {"name": "TransientLikelihoodFD"}
Expand Down Expand Up @@ -123,6 +126,9 @@ def __init__(self, **kwargs):
print("Neither run instance nor path provided.")
raise ValueError

if self.run.injection and not self.run.injection_parameters:
raise ValueError("Injection mode requires injection parameters.")

local_prior = self.initialize_prior()
local_likelihood = self.initialize_likelihood(local_prior)
self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters)
Expand Down Expand Up @@ -150,6 +156,7 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood:
waveform = self.initialize_waveform()
name = self.run.likelihood_parameters["name"]
assert isinstance(name, str), "Likelihood name must be a string."
assert name in likelihood_presets, f"Likelihood {name} not recognized."
if self.run.injection:
freqs = jnp.linspace(
self.run.data_parameters["f_min"],
Expand Down Expand Up @@ -351,3 +358,73 @@ def plot_data(self, path: str):
plt.ylabel("Amplitude")
plt.legend()
plt.savefig(path)

def sample(self):
self.jim.sample(jax.random.PRNGKey(self.run.seed))

def get_samples(self):
return self.jim.get_samples()

def plot_corner(self, path: str = "corner.jpeg", **kwargs):
"""
plot corner plot of the samples.
"""
plot_datapoint = kwargs.get("plot_datapoints", False)
title_quantiles = kwargs.get("title_quantiles", [0.16, 0.5, 0.84])
show_titles = kwargs.get("show_titles", True)
title_fmt = kwargs.get("title_fmt", ".2E")
use_math_text = kwargs.get("use_math_text", True)

samples = self.jim.get_samples()
param_names = list(samples.keys())
samples = np.array(list(samples.values())).reshape(int(len(param_names)), -1).T
corner.corner(
samples,
labels=param_names,
plot_datapoints=plot_datapoint,
title_quantiles=title_quantiles,
show_titles=show_titles,
title_fmt=title_fmt,
use_math_text=use_math_text,
**kwargs,
)
plt.savefig(path)
plt.close()

def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs):
"""
plot diagnostic plot of the samples.
"""
summary = self.jim.Sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = summary.values()
log_prob = np.array(log_prob)

plt.figure(figsize=(10, 10))
axs = [plt.subplot(2, 2, i + 1) for i in range(4)]
plt.sca(axs[0])
plt.title("log probability")
plt.plot(log_prob.mean(0))
plt.xlabel("iteration")
plt.xlim(0, None)

plt.sca(axs[1])
plt.title("NF loss")
plt.plot(loss_vals.reshape(-1))
plt.xlabel("iteration")
plt.xlim(0, None)

plt.sca(axs[2])
plt.title("Local Acceptance")
plt.plot(local_accs.mean(0))
plt.xlabel("iteration")
plt.xlim(0, None)

plt.sca(axs[3])
plt.title("Global Acceptance")
plt.plot(global_accs.mean(0))
plt.xlabel("iteration")
plt.xlim(0, None)
plt.tight_layout()

plt.savefig(path)
plt.close()
159 changes: 159 additions & 0 deletions src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,165 @@ def eta_to_q(eta: Float) -> Float:
return temp - (temp**2 - 1) ** 0.5


def spin_to_cartesian_spin(
thetaJN: Float,
phiJL: Float,
theta1: Float,
theta2: Float,
phi12: Float,
chi1: Float,
chi2: Float,
M_c: Float,
eta: Float,
fRef: Float,
phiRef: Float,
) -> tuple[Float, Float, Float, Float, Float, Float, Float]:
"""
Transforming the spin parameters
The code is based on the approach used in LALsimulation:
https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html
Parameters:
-------
thetaJN: Float
Zenith angle between the total angular momentum and the line of sight
phiJL: Float
Difference between total and orbital angular momentum azimuthal angles
theta1: Float
Zenith angle between the spin and orbital angular momenta for the primary object
theta2: Float
Zenith angle between the spin and orbital angular momenta for the secondary object
phi12: Float
Difference between the azimuthal angles of the individual spin vector projections
onto the orbital plane
chi1: Float
Primary object aligned spin:
chi2: Float
Secondary object aligned spin:
M_c: Float
The chirp mass
eta: Float
The symmetric mass ratio
fRef: Float
The reference frequency
phiRef: Float
Binary phase at a reference frequency
Returns:
-------
iota: Float
Zenith angle between the orbital angular momentum and the line of sight
S1x: Float
The x-component of the primary spin
S1y: Float
The y-component of the primary spin
S1z: Float
The z-component of the primary spin
S2x: Float
The x-component of the secondary spin
S2y: Float
The y-component of the secondary spin
S2z: Float
The z-component of the secondary spin
"""

def rotate_y(angle, vec):
"""
Rotate the vector (x, y, z) about y-axis
"""
cos_angle = jnp.cos(angle)
sin_angle = jnp.sin(angle)
rotation_matrix = jnp.array(
[[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]]
)
rotated_vec = jnp.dot(rotation_matrix, vec)
return rotated_vec

def rotate_z(angle, vec):
"""
Rotate the vector (x, y, z) about z-axis
"""
cos_angle = jnp.cos(angle)
sin_angle = jnp.sin(angle)
rotation_matrix = jnp.array(
[[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]]
)
rotated_vec = jnp.dot(rotation_matrix, vec)
return rotated_vec

LNh = jnp.array([0.0, 0.0, 1.0])

s1hat = jnp.array(
[
jnp.sin(theta1) * jnp.cos(phiRef),
jnp.sin(theta1) * jnp.sin(phiRef),
jnp.cos(theta1),
]
)
s2hat = jnp.array(
[
jnp.sin(theta2) * jnp.cos(phi12 + phiRef),
jnp.sin(theta2) * jnp.sin(phi12 + phiRef),
jnp.cos(theta2),
]
)

temp = 1 / eta / 2 - 1
q = temp - (temp**2 - 1) ** 0.5
m1, m2 = Mc_q_to_m1m2(M_c, q)
v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef)

Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0))
s1 = m1 * m1 * chi1 * s1hat
s2 = m2 * m2 * chi2 * s2hat
J = s1 + s2 + jnp.array([0.0, 0.0, Lmag])

Jhat = J / jnp.linalg.norm(J)
theta0 = jnp.arccos(Jhat[2])
phi0 = jnp.arctan2(Jhat[1], Jhat[0])

# Rotation 1:
s1hat = rotate_z(-phi0, s1hat)
s2hat = rotate_z(-phi0, s2hat)

# Rotation 2:
LNh = rotate_y(-theta0, LNh)
s1hat = rotate_y(-theta0, s1hat)
s2hat = rotate_y(-theta0, s2hat)

# Rotation 3:
LNh = rotate_z(phiJL - jnp.pi, LNh)
s1hat = rotate_z(phiJL - jnp.pi, s1hat)
s2hat = rotate_z(phiJL - jnp.pi, s2hat)

# Compute iota
N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)])
iota = jnp.arccos(jnp.dot(N, LNh))

thetaLJ = jnp.arccos(LNh[2])
phiL = jnp.arctan2(LNh[1], LNh[0])

# Rotation 4:
s1hat = rotate_z(-phiL, s1hat)
s2hat = rotate_z(-phiL, s2hat)
N = rotate_z(-phiL, N)

# Rotation 5:
s1hat = rotate_y(-thetaLJ, s1hat)
s2hat = rotate_y(-thetaLJ, s2hat)
N = rotate_y(-thetaLJ, N)

# Rotation 6:
phiN = jnp.arctan2(N[1], N[0])
s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat)
s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat)

S1 = s1hat * chi1
S2 = s2hat * chi2
return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2]


def euler_rotation(delta_x: Float[Array, " 3"]):
"""
Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x
Expand Down
48 changes: 48 additions & 0 deletions test/test_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
import jax.numpy as jnp

class TestTransform:
def test_sky_location_transform(self):
from bilby.gw.utils import zenith_azimuth_to_ra_dec as bilby_earth_to_sky
from bilby.gw.detector.networks import InterferometerList

from jimgw.single_event.utils import zenith_azimuth_to_ra_dec as jimgw_earth_to_sky
from jimgw.single_event.detector import detector_preset
from astropy.time import Time

ifos = ["H1", "L1"]
geocent_time = 1000000000

import matplotlib.pyplot as plt

for zenith in np.linspace(0, np.pi, 10):
for azimuth in np.linspace(0, 2*np.pi, 10):
bilby_sky_location = np.array(bilby_earth_to_sky(zenith, azimuth, geocent_time, InterferometerList(ifos)))
jimgw_sky_location = np.array(jimgw_earth_to_sky(zenith, azimuth, Time(geocent_time, format="gps").sidereal_time("apparent", "greenwich").rad, detector_preset[ifos[0]].vertex - detector_preset[ifos[1]].vertex))
assert np.allclose(bilby_sky_location, jimgw_sky_location, atol=1e-4)

def test_spin_transform(self):
from bilby.gw.conversion import bilby_to_lalsimulation_spins as bilby_spin_transform
from bilby.gw.conversion import symmetric_mass_ratio_to_mass_ratio, chirp_mass_and_mass_ratio_to_component_masses

from jimgw.single_event.utils import spin_to_cartesian_spin as jimgw_spin_transform

for _ in range(100):
thetaJN = jnp.array(np.random.uniform(0, np.pi))
phiJL = jnp.array(np.random.uniform(0, np.pi))
theta1 = jnp.array(np.random.uniform(0, np.pi))
theta2 = jnp.array(np.random.uniform(0, np.pi))
phi12 = jnp.array(np.random.uniform(0, np.pi))
chi1 = jnp.array(np.random.uniform(0, 1))
chi2 = jnp.array(np.random.uniform(0, 1))
M_c = jnp.array(np.random.uniform(1, 100))
eta = jnp.array(np.random.uniform(0.1, 0.25))
fRef = jnp.array(np.random.uniform(10, 1000))
phiRef = jnp.array(np.random.uniform(0, 2*np.pi))

q = symmetric_mass_ratio_to_mass_ratio(eta)
m1, m2 = chirp_mass_and_mass_ratio_to_component_masses(M_c, q)
MsunInkg = 1.9884e30
bilby_spin = jnp.array(bilby_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, m1*MsunInkg, m2*MsunInkg, fRef, phiRef))
jimgw_spin = jnp.array(jimgw_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, M_c, eta, fRef, phiRef))
assert np.allclose(bilby_spin, jimgw_spin, atol=1e-4)

0 comments on commit b003785

Please sign in to comment.