Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EarthFrame prior #87

Merged
merged 9 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from flowMC.nfmodel.base import Distribution
from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped
from beartype import beartype as typechecker
from jimgw.single_event.utils import zenith_azimuth_to_ra_dec
from jimgw.single_event.detector import GroundBased2G, detector_preset
from astropy.time import Time


class Prior(Distribution):
Expand Down Expand Up @@ -393,6 +396,75 @@ def log_prob(self, x: dict[str, Float]) -> Float:
return log_p


@jaxtyped(typechecker=typechecker)
class EarthFrame(Prior):
"""
Prior distribution for sky location in Earth frame.
"""

def __repr__(self):
return f"EarthFrame(naming={self.naming})"

def __init__(self, naming: str, gps: Float, ifos: list, **kwargs):
self.naming = ["azimuth", "zenith"]
if len(ifos) < 2:
return ValueError(
"At least two detectors are needed to define the Earth frame"
)
elif isinstance(ifos[0], str):
self.ifos = [detector_preset[ifo] for ifo in ifos[:2]]
elif isinstance(ifos[0], GroundBased2G):
self.ifos = ifos[:2]
else:
return ValueError(
"ifos should be a list of detector names or Detector objects"
)
self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad
self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex

self.transforms = {
"azimuth": (
"ra",
lambda params: zenith_azimuth_to_ra_dec(
params["zenith"],
params["azimuth"],
gmst=self.gmst,
delta_x=self.delta_x,
)[0],
),
"zenith": (
"dec",
lambda params: zenith_azimuth_to_ra_dec(
params["zenith"],
params["azimuth"],
gmst=self.gmst,
delta_x=self.delta_x,
)[1],
),
}

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
rng_keys = jax.random.split(rng_key, 2)
zenith = jnp.arccos(
jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0)
)
azimuth = jax.random.uniform(
rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi
)
return self.add_name(jnp.stack([azimuth, zenith], axis=1).T)

def log_prob(self, x: dict[str, Float]) -> Float:
zenith = x["zenith"]
azimuth = x["azimuth"]
output = jnp.where(
(azimuth > 2 * jnp.pi) | (azimuth < 0) | (zenith > jnp.pi) | (zenith < 0),
jnp.zeros_like(0) - jnp.inf,
)
return output


@jaxtyped(typechecker=typechecker)
class PowerLaw(Prior):
"""
Expand Down
1 change: 1 addition & 0 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl):
)
},
),
"EarthFrame": prior.EarthFrame,
}


Expand Down
157 changes: 157 additions & 0 deletions src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]:
return m1, m2


@jit
def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]:
"""
Transforming the right ascension ra and declination dec to the polar angle
Expand All @@ -144,6 +145,162 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa
return theta, phi


@jit
def euler_rotation(delta_x: tuple[Float, Float, Float]):
"""
Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x
while preserving the origin of the azimuthal angle.

This is decomposed into three Euler angles, alpha, beta, gamma, which rotate
about the z-, y-, and z- axes respectively.

Copied and modified from bilby-cython/geometry.pyx
"""
norm = jnp.power(
delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5
)
cos_beta = delta_x[2] / norm
sin_beta = jnp.power(1 - cos_beta**2, 0.5)

alpha = jnp.atan2(-delta_x[1] * cos_beta, delta_x[0])
gamma = jnp.atan2(delta_x[1], delta_x[0])

cos_alpha = jnp.cos(alpha)
sin_alpha = jnp.sin(alpha)
cos_gamma = jnp.cos(gamma)
sin_gamma = jnp.sin(gamma)

rotation = jnp.array(
[
[
cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma,
-sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma,
sin_beta * cos_gamma,
],
[
cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma,
-sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma,
sin_beta * sin_gamma,
],
[-cos_alpha * sin_beta, sin_alpha * sin_beta, cos_beta],
]
)

return rotation


@jit
def zenith_azimuth_to_theta_phi(
zenith: Float, azimuth: Float, delta_x: tuple[Float, Float, Float]
) -> tuple[Float, Float]:
"""
Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame.

Copied and modified from bilby-cython/geometry.pyx

Parameters
----------
zenith : Float
Zenith angle.
azimuth : Float
Azimuthal angle.
delta_x : Float
The vector pointing from the first detector to the second detector.

Returns
-------
theta : Float
Polar angle.
phi : Float
Azimuthal angle.
"""
sin_azimuth = jnp.sin(azimuth)
cos_azimuth = jnp.cos(azimuth)
sin_zenith = jnp.sin(zenith)
cos_zenith = jnp.cos(zenith)

rotation = euler_rotation(delta_x)

theta = jnp.acos(
rotation[2][0] * sin_zenith * cos_azimuth
+ rotation[2][1] * sin_zenith * sin_azimuth
+ rotation[2][2] * cos_zenith
)
phi = jnp.fmod(
jnp.atan2(
rotation[1][0] * sin_zenith * cos_azimuth
+ rotation[1][1] * sin_zenith * sin_azimuth
+ rotation[1][2] * cos_zenith,
rotation[0][0] * sin_zenith * cos_azimuth
+ rotation[0][1] * sin_zenith * sin_azimuth
+ rotation[0][2] * cos_zenith,
)
+ 2 * jnp.pi,
(2 * jnp.pi),
)
return theta, phi


@jit
def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, Float]:
"""
Transforming the polar angle and azimuthal angle to right ascension and declination.

Parameters
----------
theta : Float
Polar angle.
phi : Float
Azimuthal angle.
gmst : Float
Greenwich mean sidereal time.

Returns
-------
ra : Float
Right ascension.
dec : Float
Declination.
"""
ra = phi + gmst
dec = jnp.pi / 2 - theta
return ra, dec


@jit
def zenith_azimuth_to_ra_dec(
zenith: Float, azimuth: Float, gmst: Float, delta_x: tuple[Float, Float, Float]
) -> tuple[Float, Float]:
"""
Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination.

Parameters
----------
zenith : Float
Zenith angle.
azimuth : Float
Azimuthal angle.
gmst : Float
Greenwich mean sidereal time.
delta_x : Float
The vector pointing from the first detector to the second detector.

Copied and modified from bilby/gw/utils.py

Returns
-------
ra : Float
Right ascension.
dec : Float
Declination.
"""
theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x)
ra, dec = theta_phi_to_ra_dec(theta, phi, gmst)
ra = ra % (2 * jnp.pi)
return ra, dec


@jit
def log_i0(x):
"""
A numerically stable method to evaluate log of
Expand Down
Loading