Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasckng committed Jul 10, 2024
1 parent 1325620 commit abb7d9c
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 29 deletions.
44 changes: 33 additions & 11 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def log_prob(self, x: dict[str, Float]) -> Float:
jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax),
)
return log_p


@jaxtyped(typechecker=typechecker)
class EarthFrame(Prior):
Expand All @@ -408,34 +408,56 @@ def __repr__(self):
def __init__(self, naming: str, gps: Float, ifos: list, **kwargs):
self.naming = ["azimuth", "zenith"]
if len(ifos) < 2:
return ValueError("At least two detectors are needed to define the Earth frame")
return ValueError(
"At least two detectors are needed to define the Earth frame"
)
elif isinstance(ifos[0], str):
self.ifos = [detector_preset[ifo] for ifo in ifos[:2]]
elif isinstance(ifos[0], GroundBased2G):
self.ifos = ifos[:2]
else:
return ValueError("ifos should be a list of detector names or Detector objects")
return ValueError(
"ifos should be a list of detector names or Detector objects"
)
self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad
self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex

self.transforms = {
"azimuth": (
"ra", lambda params: zenith_azimuth_to_ra_dec(params["zenith"], params["azimuth"], gmst=self.gmst, delta_x=self.delta_x)[0]
"ra",
lambda params: zenith_azimuth_to_ra_dec(
params["zenith"],
params["azimuth"],
gmst=self.gmst,
delta_x=self.delta_x,
)[0],
),
"zenith": (
"dec", lambda params: zenith_azimuth_to_ra_dec(params["zenith"], params["azimuth"], gmst=self.gmst, delta_x=self.delta_x)[1]
"dec",
lambda params: zenith_azimuth_to_ra_dec(
params["zenith"],
params["azimuth"],
gmst=self.gmst,
delta_x=self.delta_x,
)[1],
),
}

def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> dict[str, Float[Array, " n_samples"]]:
def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
rng_keys = jax.random.split(rng_key, 2)
zenith = jnp.arccos(jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0))
azimuth = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi)
zenith = jnp.arccos(
jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0)
)
azimuth = jax.random.uniform(
rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi
)
return self.add_name(jnp.stack([azimuth, zenith], axis=1).T)

def log_prob(self, x: dict[str, Float]) -> Float:
zenith = x['zenith']
azimuth = x['azimuth']
zenith = x["zenith"]
azimuth = x["azimuth"]
output = jnp.where(
(azimuth > 2 * jnp.pi) | (azimuth < 0) | (zenith > jnp.pi) | (zenith < 0),
jnp.zeros_like(0) - jnp.inf,
Expand Down
61 changes: 43 additions & 18 deletions src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]:
m2 = m1 * q
return m1, m2


@jit
def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]:
"""
Expand Down Expand Up @@ -156,7 +157,9 @@ def euler_rotation(delta_x: tuple[Float, Float, Float]):
Copied and modified from bilby-cython/geometry.pyx
"""
norm = jnp.power(delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5)
norm = jnp.power(
delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5
)
cos_beta = delta_x[2] / norm
sin_beta = jnp.power(1 - cos_beta**2, 0.5)

Expand All @@ -168,15 +171,29 @@ def euler_rotation(delta_x: tuple[Float, Float, Float]):
cos_gamma = jnp.cos(gamma)
sin_gamma = jnp.sin(gamma)

rotation = jnp.array([[cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma, -sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma, sin_beta * cos_gamma],
[cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma, -sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma, sin_beta * sin_gamma],
[-cos_alpha * sin_beta, sin_alpha * sin_beta, cos_beta]])
rotation = jnp.array(
[
[
cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma,
-sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma,
sin_beta * cos_gamma,
],
[
cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma,
-sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma,
sin_beta * sin_gamma,
],
[-cos_alpha * sin_beta, sin_alpha * sin_beta, cos_beta],
]
)

return rotation


@jit
def zenith_azimuth_to_theta_phi(zenith: Float, azimuth: Float, delta_x: tuple[Float, Float, Float]) -> tuple[Float, Float]:
def zenith_azimuth_to_theta_phi(
zenith: Float, azimuth: Float, delta_x: tuple[Float, Float, Float]
) -> tuple[Float, Float]:
"""
Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame.
Expand Down Expand Up @@ -205,18 +222,23 @@ def zenith_azimuth_to_theta_phi(zenith: Float, azimuth: Float, delta_x: tuple[Fl

rotation = euler_rotation(delta_x)

theta = jnp.acos(rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth + rotation[2][2] * cos_zenith)
theta = jnp.acos(
rotation[2][0] * sin_zenith * cos_azimuth
+ rotation[2][1] * sin_zenith * sin_azimuth
+ rotation[2][2] * cos_zenith
)
phi = jnp.fmod(
jnp.atan2(
rotation[1][0] * sin_zenith * cos_azimuth
+ rotation[1][1] * sin_zenith * sin_azimuth
+ rotation[1][2] * cos_zenith,
rotation[0][0] * sin_zenith * cos_azimuth
+ rotation[0][1] * sin_zenith * sin_azimuth
+ rotation[0][2] * cos_zenith)
+ 2 * jnp.pi,
(2 * jnp.pi)
)
jnp.atan2(
rotation[1][0] * sin_zenith * cos_azimuth
+ rotation[1][1] * sin_zenith * sin_azimuth
+ rotation[1][2] * cos_zenith,
rotation[0][0] * sin_zenith * cos_azimuth
+ rotation[0][1] * sin_zenith * sin_azimuth
+ rotation[0][2] * cos_zenith,
)
+ 2 * jnp.pi,
(2 * jnp.pi),
)
return theta, phi


Expand Down Expand Up @@ -247,9 +269,11 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F


@jit
def zenith_azimuth_to_ra_dec(zenith: Float, azimuth: Float, gmst: Float, delta_x: tuple[Float, Float, Float]) -> tuple[Float, Float]:
def zenith_azimuth_to_ra_dec(
zenith: Float, azimuth: Float, gmst: Float, delta_x: tuple[Float, Float, Float]
) -> tuple[Float, Float]:
"""
Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination.
Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination.
Parameters
----------
Expand All @@ -276,6 +300,7 @@ def zenith_azimuth_to_ra_dec(zenith: Float, azimuth: Float, gmst: Float, delta_x
ra = ra % (2 * jnp.pi)
return ra, dec


@jit
def log_i0(x):
"""
Expand Down

0 comments on commit abb7d9c

Please sign in to comment.