diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 832c84a9..cb60e98b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -111,77 +111,38 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +def named_m1_m2_to_Mc_q(x): + Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) + return {"M_c": Mc, "q": q} + +def named_Mc_q_to_m1_m2(x): + m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} + +ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"])) +ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q +ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2 + +def named_m1_m2_to_Mc_eta(x): + Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) + return {"M_c": Mc, "eta": eta} + +def named_Mc_eta_to_m1_m2(x): + m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"]) + return {"m_1": m1, "m_2": m2} + +ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"])) +ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta +ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2 + +def named_q_to_eta(x): + return {"eta": q_to_eta(x["q"])} +def named_eta_to_q(x): + return {"q": eta_to_q(x["eta"])} +MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"])) +MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta +MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q -@jaxtyped(typechecker=typechecker) -class _ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): - """ - Transform chirp mass and mass ratio to component masses. Instantiated to object below. - """ - - def __init__(self): - name_mapping = (["m_1", "m_2"], ["M_c", "q"]) - super().__init__(name_mapping) - - def named_transform(x): - Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) - return {"M_c": Mc, "q": q} - - self.transform_func = named_transform - - def named_inverse_transform(x): - m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) - return {"m_1": m1, "m_2": m2} - - self.inverse_transform_func = named_inverse_transform - - -@jaxtyped(typechecker=typechecker) -class _ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform): - """ - Transform mass ratio to symmetric mass ratio. Instantiated to object below. - """ - - def __init__(self): - name_mapping = (["m_1", "m_2"], ["M_c", "eta"]) - super().__init__(name_mapping) - - def named_transform(x): - Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) - return {"M_c": Mc, "eta": eta} - - self.transform_func = named_transform - - def named_inverse_transform(x): - m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"]) - return {"m_1": m1, "m_2": m2} - - self.inverse_transform_func = named_inverse_transform - - -@jaxtyped(typechecker=typechecker) -class _MassRatioToSymmetricMassRatioTransform(BijectiveTransform): - """ - Transform mass ratio to symmetric mass ratio. Instantiated to object below. - """ - - def __init__(self): - name_mapping = (["q"], ["eta"]) - super().__init__(name_mapping) - - self.transform_func = lambda x: {"eta": q_to_eta(x["q"])} - self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])} - - def __repr__(self): - return f"{self.__class__.__name__[1:]}()" - - -ComponentMassesToChirpMassMassRatioTransform = ( - _ComponentMassesToChirpMassMassRatioTransform() -) -ComponentMassesToChirpMassSymmetricMassRatioTransform = ( - _ComponentMassesToChirpMassSymmetricMassRatioTransform() -) -MassRatioToSymmetricMassRatioTransform = _MassRatioToSymmetricMassRatioTransform() ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform( ComponentMassesToChirpMassMassRatioTransform