Skip to content

Commit

Permalink
Added inverse_conditional_names
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyuon committed Sep 4, 2024
1 parent 2949604 commit d090330
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 6 deletions.
76 changes: 71 additions & 5 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def __init__(
):
name_mapping = [["t_c"], ["t_det_unbounded"]]
conditional_names = ["ra", "dec"]
super().__init__(name_mapping, conditional_names)
inverse_conditional_names = ["ra", "dec"]
super().__init__(name_mapping, conditional_names, inverse_conditional_names)

self.gmst = (
Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad
Expand Down Expand Up @@ -231,8 +232,40 @@ def __init__(
has_iota: bool = True,
):
name_mapping = [["phase_c"], ["phase_det"]]
conditional_names = ["ra", "dec", "psi", "iota"]
super().__init__(name_mapping, conditional_names)
if has_iota:
conditional_names = ["ra", "dec", "psi", "iota"]
inverse_conditional_names = ["ra", "dec", "psi", "iota"]
else:
conditional_names = [
"ra",
"dec",
"psi",
"theta_jn",
"phi_jl",
"theta_1",
"theta_2",
"phi_12",
"a_1",
"a_2",
"M_c",
"q",
"phase_c"
]
inverse_conditional_names = [
"ra",
"dec",
"psi",
"theta_jn",
"phi_jl",
"theta_1",
"theta_2",
"phi_12",
"a_1",
"a_2",
"M_c",
"q",
]
super().__init__(name_mapping, conditional_names, inverse_conditional_names)

self.gmst = (
Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad
Expand Down Expand Up @@ -339,8 +372,41 @@ def __init__(
has_iota: bool = True,
):
name_mapping = [["d_L"], ["d_hat_unbounded"]]
conditional_names = ["M_c", "ra", "dec", "psi", "iota"]
super().__init__(name_mapping, conditional_names)
if has_iota:
conditional_names = ["ra", "dec", "psi", "iota"]
inverse_conditional_names = ["ra", "dec", "psi", "iota"]
else:
conditional_names = [
"ra",
"dec",
"psi",
"theta_jn",
"phi_jl",
"theta_1",
"theta_2",
"phi_12",
"a_1",
"a_2",
"M_c",
"q",
"phase_c"
]
inverse_conditional_names = [
"ra",
"dec",
"psi",
"theta_jn",
"phi_jl",
"theta_1",
"theta_2",
"phi_12",
"a_1",
"a_2",
"M_c",
"q",
"phase_c"
]
super().__init__(name_mapping, conditional_names, inverse_conditional_names)

self.gmst = (
Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad
Expand Down
5 changes: 4 additions & 1 deletion src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,17 @@ def backward(self, y: dict[str, Float]) -> dict[str, Float]:
class ConditionalBijectiveTransform(BijectiveTransform):

conditional_names: list[str]
inverse_conditional_names: list[str]

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
conditional_names: list[str],
inverse_conditional_names: list[str],
):
super().__init__(name_mapping)
self.conditional_names = conditional_names
self.inverse_conditional_names = inverse_conditional_names

def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]:
x_copy = x.copy()
Expand Down Expand Up @@ -204,7 +207,7 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]:
y_copy = y.copy()
transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1])
transform_params.update(
dict((key, y_copy[key]) for key in self.conditional_names)
dict((key, y_copy[key]) for key in self.inverse_conditional_names)
)
output_params = self.inverse_transform_func(transform_params)
jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params)
Expand Down

0 comments on commit d090330

Please sign in to comment.