Skip to content

Commit

Permalink
Hard-code transform name_mapping and conditional_parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
tsunhopang committed Sep 2, 2024
1 parent 2fbfc04 commit ce7b308
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 30 deletions.
32 changes: 6 additions & 26 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform(

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
conditional_names: list[str],
gps_time: Float,
ifo: GroundBased2G,
tc_min: Float,
tc_max: Float,
):
name_mapping = [["t_c"], ["t_det_unbounded"]]
conditional_names = ["ra", "dec"]
super().__init__(name_mapping, conditional_names)

self.gmst = (
Expand All @@ -156,9 +156,6 @@ def __init__(
self.tc_min = tc_min
self.tc_max = tc_max

assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1]
assert "ra" in conditional_names and "dec" in conditional_names

@jnp.vectorize
def time_delay(ra, dec, gmst):
return self.ifo.delay_from_geocenter(ra, dec, gmst)
Expand Down Expand Up @@ -225,26 +222,18 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
conditional_names: list[str],
gps_time: Float,
ifo: GroundBased2G,
):
name_mapping = [["phase_c"], ["phase_det"]]
conditional_names = ["ra", "dec", "psi", "iota"]
super().__init__(name_mapping, conditional_names)

self.gmst = (
Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad
)
self.ifo = ifo

assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1]
assert (
"ra" in conditional_names
and "dec" in conditional_names
and "psi" in conditional_names
and "iota" in conditional_names
)

@jnp.vectorize
def _calc_R_det_arg(ra, dec, psi, iota, gmst):
p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0
Expand Down Expand Up @@ -298,13 +287,13 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform):

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
conditional_names: list[str],
gps_time: Float,
ifos: list[GroundBased2G],
dL_min: Float,
dL_max: Float,
):
name_mapping = [["d_L"], ["d_hat_unbounded"]]
conditional_names = ["M_c", "ra", "dec", "psi", "iota"]
super().__init__(name_mapping, conditional_names)

self.gmst = (
Expand All @@ -314,15 +303,6 @@ def __init__(
self.dL_min = dL_min
self.dL_max = dL_max

assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1]
assert (
"ra" in conditional_names
and "dec" in conditional_names
and "psi" in conditional_names
and "iota" in conditional_names
and "M_c" in conditional_names
)

@jnp.vectorize
def _calc_R_dets(ra, dec, psi, iota):
p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0
Expand Down
8 changes: 4 additions & 4 deletions test/integration/test_extrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@

sample_transforms = [
# all the user reparametrization transform
DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]),
SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos),
DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
# all the bound to unbound transform
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0),
BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi),
Expand Down

0 comments on commit ce7b308

Please sign in to comment.