diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 712abf1a..e5675c8f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -6,6 +6,9 @@ from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped from beartype import beartype as typechecker +from jimgw.single_event.utils import zenith_azimuth_to_ra_dec +from jimgw.single_event.detector import GroundBased2G, detector_preset +from astropy.time import Time class Prior(Distribution): @@ -393,6 +396,75 @@ def log_prob(self, x: dict[str, Float]) -> Float: return log_p +@jaxtyped(typechecker=typechecker) +class EarthFrame(Prior): + """ + Prior distribution for sky location in Earth frame. + """ + + def __repr__(self): + return f"EarthFrame(naming={self.naming})" + + def __init__(self, naming: str, gps: Float, ifos: list, **kwargs): + self.naming = ["azimuth", "zenith"] + if len(ifos) < 2: + return ValueError( + "At least two detectors are needed to define the Earth frame" + ) + elif isinstance(ifos[0], str): + self.ifos = [detector_preset[ifo] for ifo in ifos[:2]] + elif isinstance(ifos[0], GroundBased2G): + self.ifos = ifos[:2] + else: + return ValueError( + "ifos should be a list of detector names or Detector objects" + ) + self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad + self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex + + self.transforms = { + "azimuth": ( + "ra", + lambda params: zenith_azimuth_to_ra_dec( + params["zenith"], + params["azimuth"], + gmst=self.gmst, + delta_x=self.delta_x, + )[0], + ), + "zenith": ( + "dec", + lambda params: zenith_azimuth_to_ra_dec( + params["zenith"], + params["azimuth"], + gmst=self.gmst, + delta_x=self.delta_x, + )[1], + ), + } + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + rng_keys = jax.random.split(rng_key, 2) + zenith = jnp.arccos( + jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) + ) + azimuth = jax.random.uniform( + rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi + ) + return self.add_name(jnp.stack([azimuth, zenith], axis=1).T) + + def log_prob(self, x: dict[str, Float]) -> Float: + zenith = x["zenith"] + azimuth = x["azimuth"] + output = jnp.where( + (azimuth > 2 * jnp.pi) | (azimuth < 0) | (zenith > jnp.pi) | (zenith < 0), + jnp.zeros_like(0) - jnp.inf, + ) + return output + + @jaxtyped(typechecker=typechecker) class PowerLaw(Prior): """ diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index d520f561..56eb9925 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -58,6 +58,7 @@ def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl): ) }, ), + "EarthFrame": prior.EarthFrame, } diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index c0b14b7a..72ef5e36 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -118,6 +118,7 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: return m1, m2 +@jit def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: """ Transforming the right ascension ra and declination dec to the polar angle @@ -144,6 +145,162 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa return theta, phi +@jit +def euler_rotation(delta_x: tuple[Float, Float, Float]): + """ + Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x + while preserving the origin of the azimuthal angle. + + This is decomposed into three Euler angles, alpha, beta, gamma, which rotate + about the z-, y-, and z- axes respectively. + + Copied and modified from bilby-cython/geometry.pyx + """ + norm = jnp.power( + delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5 + ) + cos_beta = delta_x[2] / norm + sin_beta = jnp.power(1 - cos_beta**2, 0.5) + + alpha = jnp.atan2(-delta_x[1] * cos_beta, delta_x[0]) + gamma = jnp.atan2(delta_x[1], delta_x[0]) + + cos_alpha = jnp.cos(alpha) + sin_alpha = jnp.sin(alpha) + cos_gamma = jnp.cos(gamma) + sin_gamma = jnp.sin(gamma) + + rotation = jnp.array( + [ + [ + cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma, + -sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma, + sin_beta * cos_gamma, + ], + [ + cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma, + -sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma, + sin_beta * sin_gamma, + ], + [-cos_alpha * sin_beta, sin_alpha * sin_beta, cos_beta], + ] + ) + + return rotation + + +@jit +def zenith_azimuth_to_theta_phi( + zenith: Float, azimuth: Float, delta_x: tuple[Float, Float, Float] +) -> tuple[Float, Float]: + """ + Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame. + + Copied and modified from bilby-cython/geometry.pyx + + Parameters + ---------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + delta_x : Float + The vector pointing from the first detector to the second detector. + + Returns + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ + sin_azimuth = jnp.sin(azimuth) + cos_azimuth = jnp.cos(azimuth) + sin_zenith = jnp.sin(zenith) + cos_zenith = jnp.cos(zenith) + + rotation = euler_rotation(delta_x) + + theta = jnp.acos( + rotation[2][0] * sin_zenith * cos_azimuth + + rotation[2][1] * sin_zenith * sin_azimuth + + rotation[2][2] * cos_zenith + ) + phi = jnp.fmod( + jnp.atan2( + rotation[1][0] * sin_zenith * cos_azimuth + + rotation[1][1] * sin_zenith * sin_azimuth + + rotation[1][2] * cos_zenith, + rotation[0][0] * sin_zenith * cos_azimuth + + rotation[0][1] * sin_zenith * sin_azimuth + + rotation[0][2] * cos_zenith, + ) + + 2 * jnp.pi, + (2 * jnp.pi), + ) + return theta, phi + + +@jit +def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the polar angle and azimuthal angle to right ascension and declination. + + Parameters + ---------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + ra : Float + Right ascension. + dec : Float + Declination. + """ + ra = phi + gmst + dec = jnp.pi / 2 - theta + return ra, dec + + +@jit +def zenith_azimuth_to_ra_dec( + zenith: Float, azimuth: Float, gmst: Float, delta_x: tuple[Float, Float, Float] +) -> tuple[Float, Float]: + """ + Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. + + Parameters + ---------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + gmst : Float + Greenwich mean sidereal time. + delta_x : Float + The vector pointing from the first detector to the second detector. + + Copied and modified from bilby/gw/utils.py + + Returns + ------- + ra : Float + Right ascension. + dec : Float + Declination. + """ + theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) + ra = ra % (2 * jnp.pi) + return ra, dec + + +@jit def log_i0(x): """ A numerically stable method to evaluate log of