diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index da7864ba..62a09fd6 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -140,13 +140,13 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform( def __init__( self, - name_mapping: tuple[list[str], list[str]], - conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, tc_min: Float, tc_max: Float, ): + name_mapping = [["t_c"], ["t_det_unbounded"]] + conditional_names = ["ra", "dec"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -225,11 +225,11 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( def __init__( self, - name_mapping: tuple[list[str], list[str]], - conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, ): + name_mapping = [["phase_c"], ["phase_det"]] + conditional_names = ["ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -305,6 +305,8 @@ def __init__( dL_min: Float, dL_max: Float, ): + name_mapping = [["d_L"], ["d_hat_unbounded"]] + conditional_names = ["M_c", "ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -381,96 +383,54 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform -@jaxtyped(typechecker=typechecker) -class SpinToCartesianSpinTransform(NtoNTransform): - """ - Spin to Cartesian spin transformation - """ - - freq_ref: Float - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - freq_ref: Float, - ): - super().__init__(name_mapping) - - self.freq_ref = freq_ref - - assert ( - "theta_jn" in name_mapping[0] - and "phi_jl" in name_mapping[0] - and "theta_1" in name_mapping[0] - and "theta_2" in name_mapping[0] - and "phi_12" in name_mapping[0] - and "a_1" in name_mapping[0] - and "a_2" in name_mapping[0] - and "iota" in name_mapping[1] - and "s1_x" in name_mapping[1] - and "s1_y" in name_mapping[1] - and "s1_z" in name_mapping[1] - and "s2_x" in name_mapping[1] - and "s2_y" in name_mapping[1] - and "s2_z" in name_mapping[1] - ) - - def named_transform(x): - iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( - x["theta_jn"], - x["phi_jl"], - x["theta_1"], - x["theta_2"], - x["phi_12"], - x["a_1"], - x["a_2"], - x["M_c"], - x["q"], - self.freq_ref, - x["phase_c"], - ) - return { - "iota": iota, - "s1_x": s1x, - "s1_y": s1y, - "s1_z": s1z, - "s2_x": s2x, - "s2_y": s2y, - "s2_z": s2z, - } - - self.transform_func = named_transform - - def named_m1_m2_to_Mc_q(x): Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) - return {"M_c": Mc, "q": q} - - + return {"M_c": Mc, "q": q} + + def named_Mc_q_to_m1_m2(x): m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) return {"m_1": m1, "m_2": m2} -ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"])) + +ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "q"]) +) ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q -ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2 +ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = ( + named_Mc_q_to_m1_m2 +) + def named_m1_m2_to_Mc_eta(x): Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) return {"M_c": Mc, "eta": eta} + def named_Mc_eta_to_m1_m2(x): m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"]) return {"m_1": m1, "m_2": m2} -ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"])) -ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta -ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2 + +ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "eta"]) +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = ( + named_m1_m2_to_Mc_eta +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = ( + named_Mc_eta_to_m1_m2 +) + def named_q_to_eta(x): return {"eta": q_to_eta(x["q"])} + + def named_eta_to_q(x): return {"q": eta_to_q(x["eta"])} + + MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"])) MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index f0e089fe..300f2132 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -47,10 +47,10 @@ sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi),