Skip to content

Commit

Permalink
Merge pull request kazewong#87 from thomasckng/main
Browse files Browse the repository at this point in the history
Add EarthFrame prior
  • Loading branch information
kazewong authored Jul 10, 2024
2 parents 5a5e45d + edf316a commit 94c34db
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 0 deletions.
72 changes: 72 additions & 0 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from flowMC.nfmodel.base import Distribution
from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped
from beartype import beartype as typechecker
from jimgw.single_event.utils import zenith_azimuth_to_ra_dec
from jimgw.single_event.detector import GroundBased2G, detector_preset
from astropy.time import Time


class Prior(Distribution):
Expand Down Expand Up @@ -393,6 +396,75 @@ def log_prob(self, x: dict[str, Float]) -> Float:
return log_p


@jaxtyped(typechecker=typechecker)
class EarthFrame(Prior):
"""
Prior distribution for sky location in Earth frame.
"""

def __repr__(self):
return f"EarthFrame(naming={self.naming})"

def __init__(self, naming: str, gps: Float, ifos: list, **kwargs):
self.naming = ["azimuth", "zenith"]
if len(ifos) < 2:
return ValueError(
"At least two detectors are needed to define the Earth frame"
)
elif isinstance(ifos[0], str):
self.ifos = [detector_preset[ifo] for ifo in ifos[:2]]
elif isinstance(ifos[0], GroundBased2G):
self.ifos = ifos[:2]
else:
return ValueError(
"ifos should be a list of detector names or Detector objects"
)
self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad
self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex

self.transforms = {
"azimuth": (
"ra",
lambda params: zenith_azimuth_to_ra_dec(
params["zenith"],
params["azimuth"],
gmst=self.gmst,
delta_x=self.delta_x,
)[0],
),
"zenith": (
"dec",
lambda params: zenith_azimuth_to_ra_dec(
params["zenith"],
params["azimuth"],
gmst=self.gmst,
delta_x=self.delta_x,
)[1],
),
}

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
rng_keys = jax.random.split(rng_key, 2)
zenith = jnp.arccos(
jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0)
)
azimuth = jax.random.uniform(
rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi
)
return self.add_name(jnp.stack([azimuth, zenith], axis=1).T)

def log_prob(self, x: dict[str, Float]) -> Float:
zenith = x["zenith"]
azimuth = x["azimuth"]
output = jnp.where(
(azimuth > 2 * jnp.pi) | (azimuth < 0) | (zenith > jnp.pi) | (zenith < 0),
jnp.zeros_like(0) - jnp.inf,
)
return output


@jaxtyped(typechecker=typechecker)
class PowerLaw(Prior):
"""
Expand Down
1 change: 1 addition & 0 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl):
)
},
),
"EarthFrame": prior.EarthFrame,
}


Expand Down
157 changes: 157 additions & 0 deletions src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]:
return m1, m2


@jit
def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]:
"""
Transforming the right ascension ra and declination dec to the polar angle
Expand All @@ -144,6 +145,162 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa
return theta, phi


@jit
def euler_rotation(delta_x: tuple[Float, Float, Float]):
"""
Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x
while preserving the origin of the azimuthal angle.
This is decomposed into three Euler angles, alpha, beta, gamma, which rotate
about the z-, y-, and z- axes respectively.
Copied and modified from bilby-cython/geometry.pyx
"""
norm = jnp.power(
delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5
)
cos_beta = delta_x[2] / norm
sin_beta = jnp.power(1 - cos_beta**2, 0.5)

alpha = jnp.atan2(-delta_x[1] * cos_beta, delta_x[0])
gamma = jnp.atan2(delta_x[1], delta_x[0])

cos_alpha = jnp.cos(alpha)
sin_alpha = jnp.sin(alpha)
cos_gamma = jnp.cos(gamma)
sin_gamma = jnp.sin(gamma)

rotation = jnp.array(
[
[
cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma,
-sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma,
sin_beta * cos_gamma,
],
[
cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma,
-sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma,
sin_beta * sin_gamma,
],
[-cos_alpha * sin_beta, sin_alpha * sin_beta, cos_beta],
]
)

return rotation


@jit
def zenith_azimuth_to_theta_phi(
zenith: Float, azimuth: Float, delta_x: tuple[Float, Float, Float]
) -> tuple[Float, Float]:
"""
Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame.
Copied and modified from bilby-cython/geometry.pyx
Parameters
----------
zenith : Float
Zenith angle.
azimuth : Float
Azimuthal angle.
delta_x : Float
The vector pointing from the first detector to the second detector.
Returns
-------
theta : Float
Polar angle.
phi : Float
Azimuthal angle.
"""
sin_azimuth = jnp.sin(azimuth)
cos_azimuth = jnp.cos(azimuth)
sin_zenith = jnp.sin(zenith)
cos_zenith = jnp.cos(zenith)

rotation = euler_rotation(delta_x)

theta = jnp.acos(
rotation[2][0] * sin_zenith * cos_azimuth
+ rotation[2][1] * sin_zenith * sin_azimuth
+ rotation[2][2] * cos_zenith
)
phi = jnp.fmod(
jnp.atan2(
rotation[1][0] * sin_zenith * cos_azimuth
+ rotation[1][1] * sin_zenith * sin_azimuth
+ rotation[1][2] * cos_zenith,
rotation[0][0] * sin_zenith * cos_azimuth
+ rotation[0][1] * sin_zenith * sin_azimuth
+ rotation[0][2] * cos_zenith,
)
+ 2 * jnp.pi,
(2 * jnp.pi),
)
return theta, phi


@jit
def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, Float]:
"""
Transforming the polar angle and azimuthal angle to right ascension and declination.
Parameters
----------
theta : Float
Polar angle.
phi : Float
Azimuthal angle.
gmst : Float
Greenwich mean sidereal time.
Returns
-------
ra : Float
Right ascension.
dec : Float
Declination.
"""
ra = phi + gmst
dec = jnp.pi / 2 - theta
return ra, dec


@jit
def zenith_azimuth_to_ra_dec(
zenith: Float, azimuth: Float, gmst: Float, delta_x: tuple[Float, Float, Float]
) -> tuple[Float, Float]:
"""
Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination.
Parameters
----------
zenith : Float
Zenith angle.
azimuth : Float
Azimuthal angle.
gmst : Float
Greenwich mean sidereal time.
delta_x : Float
The vector pointing from the first detector to the second detector.
Copied and modified from bilby/gw/utils.py
Returns
-------
ra : Float
Right ascension.
dec : Float
Declination.
"""
theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x)
ra, dec = theta_phi_to_ra_dec(theta, phi, gmst)
ra = ra % (2 * jnp.pi)
return ra, dec


@jit
def log_i0(x):
"""
A numerically stable method to evaluate log of
Expand Down

0 comments on commit 94c34db

Please sign in to comment.