Skip to content

Commit

Permalink
Merge pull request #132 from ThibeauWouters/98-moving-naming-tracking…
Browse files Browse the repository at this point in the history
…-into-jim-class-from-prior-class

Added code to reverse transforms for more flexibility
  • Loading branch information
kazewong authored Sep 2, 2024
2 parents 7910785 + ad90d09 commit 0e96439
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 162 deletions.
222 changes: 73 additions & 149 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from astropy.time import Time

from jimgw.single_event.detector import GroundBased2G
from jimgw.transforms import BijectiveTransform, NtoNTransform
from jimgw.transforms import (
BijectiveTransform,
NtoNTransform,
reverse_bijective_transform,
)
from jimgw.single_event.utils import (
m1_m2_to_Mc_q,
Mc_q_to_m1_m2,
Expand All @@ -20,111 +24,56 @@


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform):
"""
Transform chirp mass and mass ratio to component masses
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
class SpinToCartesianSpinTransform(NtoNTransform):
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "q" in name_mapping[1]
)

def named_transform(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform):
Spin to Cartesian spin transformation
"""
Transform mass ratio to symmetric mass ratio

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""
freq_ref: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "eta" in name_mapping[1]
name_mapping = (
["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"],
["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"],
)
super().__init__(name_mapping)

self.freq_ref = freq_ref

def named_transform(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class MassRatioToSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0]

self.transform_func = lambda x: {"eta": q_to_eta(x["q"])}
self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])}


@jaxtyped(typechecker=typechecker)
class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):
"""
Transform sky frame to detector frame sky position
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

gmst: Float
Expand All @@ -133,10 +82,10 @@ class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
gps_time: Float,
ifos: GroundBased2G,
):
name_mapping = (["ra", "dec"], ["zenith", "azimuth"])
super().__init__(name_mapping)

self.gmst = (
Expand All @@ -146,13 +95,6 @@ def __init__(
self.rotation = euler_rotation(delta_x)
self.rotation_inv = jnp.linalg.inv(self.rotation)

assert (
"ra" in name_mapping[0]
and "dec" in name_mapping[0]
and "zenith" in name_mapping[1]
and "azimuth" in name_mapping[1]
)

def named_transform(x):
zenith, azimuth = ra_dec_to_zenith_azimuth(
x["ra"], x["dec"], self.gmst, self.rotation
Expand All @@ -169,63 +111,45 @@ def named_inverse_transform(x):

self.inverse_transform_func = named_inverse_transform

def named_m1_m2_to_Mc_q(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

@jaxtyped(typechecker=typechecker)
class SpinToCartesianSpinTransform(NtoNTransform):
"""
Spin to Cartesian spin transformation
"""
def named_Mc_q_to_m1_m2(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

freq_ref: Float
ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"]))
ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q
ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)
def named_m1_m2_to_Mc_eta(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}

self.freq_ref = freq_ref
def named_Mc_eta_to_m1_m2(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"])
return {"m_1": m1, "m_2": m2}

assert (
"theta_jn" in name_mapping[0]
and "phi_jl" in name_mapping[0]
and "theta_1" in name_mapping[0]
and "theta_2" in name_mapping[0]
and "phi_12" in name_mapping[0]
and "a_1" in name_mapping[0]
and "a_2" in name_mapping[0]
and "iota" in name_mapping[1]
and "s1_x" in name_mapping[1]
and "s1_y" in name_mapping[1]
and "s1_z" in name_mapping[1]
and "s2_x" in name_mapping[1]
and "s2_y" in name_mapping[1]
and "s2_z" in name_mapping[1]
)
ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"]))
ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta
ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2

def named_transform(x):
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}
def named_q_to_eta(x):
return {"eta": q_to_eta(x["q"])}
def named_eta_to_q(x):
return {"q": eta_to_q(x["eta"])}
MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"]))
MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta
MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q

self.transform_func = named_transform

ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassMassRatioTransform
)
ChirpMassSymmetricMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassSymmetricMassRatioTransform
)
SymmetricMassRatioToMassRatioTransform = reverse_bijective_transform(
MassRatioToSymmetricMassRatioTransform
)
16 changes: 16 additions & 0 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,19 @@ def __init__(
)
for i in range(len(name_mapping[1]))
}


def reverse_bijective_transform(
original_transform: BijectiveTransform,
) -> BijectiveTransform:

reversed_name_mapping = (
original_transform.name_mapping[1],
original_transform.name_mapping[0],
)
reversed_transform = BijectiveTransform(name_mapping=reversed_name_mapping)
reversed_transform.transform_func = original_transform.inverse_transform_func
reversed_transform.inverse_transform_func = original_transform.transform_func
reversed_transform.__repr__ = lambda: f"Reversed{repr(original_transform)}"

return reversed_transform
2 changes: 2 additions & 0 deletions test/integration/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
outdir/
figures/
16 changes: 10 additions & 6 deletions test/integration/test_GW150914_D.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10"

import jax
import jax.numpy as jnp

Expand All @@ -10,6 +14,8 @@
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam
from flowMC.utils.postprocessing import plot_summary
import optax

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -62,7 +68,7 @@
)

sample_transforms = [
ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]),
ComponentMassesToChirpMassMassRatioTransform,
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
Expand All @@ -72,13 +78,13 @@
BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]

likelihood_transforms = [
ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]),
ComponentMassesToChirpMassSymmetricMassRatioTransform,
]

likelihood = TransientLikelihoodFD(
Expand All @@ -89,7 +95,6 @@
post_trigger_duration=2,
)


mass_matrix = jnp.eye(11)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
Expand All @@ -101,7 +106,6 @@
n_loop_training = 1
learning_rate = 1e-4


jim = Jim(
likelihood,
prior,
Expand All @@ -127,4 +131,4 @@

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()
jim.print_summary()
Loading

0 comments on commit 0e96439

Please sign in to comment.