From 801bdb96ae760c7bbf71641ed747c3d8659c07e3 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Tue, 2 Jul 2024 13:36:23 +0800 Subject: [PATCH 1/8] Add EarthFrame prior [WIP] --- src/jimgw/prior.py | 38 +++++++++ src/jimgw/single_event/utils.py | 141 +++++++++++++++++++++++++++++++- 2 files changed, 178 insertions(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 6408318d..545bee23 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -391,6 +391,44 @@ def log_prob(self, x: dict[str, Float]) -> Float: jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), ) return log_p + +class EarthFrame(Prior): + """ + Prior distribution for sky location in Earth frame. + """ + + def __repr__(self): + return f"EarthFrame(naming={self.naming})" + + def __init__(self, naming: str, **kwargs): + self.naming = ["azimuth", "zenith"] + self.transforms = { + "azimuth": ( + "ra", + lambda params: + ), + "zenith": ( + "dec", + lambda params: + ), + } + + def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> dict[str, Float[Array, " n_samples"]]: + rng_keys = jax.random.split(rng_key, 2) + zenith = jnp.arccos( + jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) + ) + azimuth = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) + return self.add_name(jnp.stack([azimuth, zenith], axis=1).T) + + def log_prob(self, x: dict[str, Float]) -> Float: + zenith = x['zenith'] + azimuth = x['azimuth'] + output = jnp.where( + (azimuth > 2 * jnp.pi) | (azimuth < 0) | (zenith > jnp.pi) | (zenith < 0), + jnp.zeros_like(0) - jnp.inf, + ) + return output @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index c0b14b7a..8a0ee633 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -117,7 +117,7 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: m2 = m1 * q return m1, m2 - +@jit def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: """ Transforming the right ascension ra and declination dec to the polar angle @@ -144,6 +144,145 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa return theta, phi +@jit +def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the polar angle and azimuthal angle to right ascension and declination. + + Parameters + ---------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + ra : Float + Right ascension. + dec : Float + Declination. + """ + ra = phi + gmst + dec = jnp.pi / 2 - theta + return ra, dec + + +@jit +def euler_rotation(delta_x: tuple[Float, Float, Float]): + """ + Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x + while preserving the origin of the azimuthal angle. + + This is decomposed into three Euler angles, alpha, beta, gamma, which rotate + about the z-, y-, and z- axes respectively. + + Copied and modified from bilby-cython/geometry.pyx + """ + norm = jnp.power(delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5) + cos_beta = delta_x[2] / norm + sin_beta = jnp.power(1 - cos_beta**2, 0.5) + + alpha = jnp.atan2(- delta_x[1] * cos_beta, delta_x[0]) + gamma = jnp.atan2(delta_x[1], delta_x[0]) + + cos_alpha = jnp.cos(alpha) + sin_alpha = jnp.sin(alpha) + cos_gamma = jnp.cos(gamma) + sin_gamma = jnp.sin(gamma) + + rotation = jnp.empty((3, 3)) + + rotation[0][0] = cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma + rotation[1][0] = cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma + rotation[2][0] = -cos_alpha * sin_beta + rotation[0][1] = -sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma + rotation[1][1] = -sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma + rotation[2][1] = sin_alpha * sin_beta + rotation[0][2] = sin_beta * cos_gamma + rotation[1][2] = sin_beta * sin_gamma + rotation[2][2] = cos_beta + + return rotation + + +@jit +def zenith_azimuth_to_theta_phi(zenith: Float, azimuth: Float, delta_x: tuple[Float, Float, Float]) -> tuple[Float, Float]: + """ + Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame. + + Copied and modified from bilby-cython/geometry.pyx + + Parameters + ---------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + delta_x : Float + The vector pointing from the first detector to the second detector. + + Returns + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ + sin_azimuth = jnp.sin(azimuth) + cos_azimuth = jnp.cos(azimuth) + sin_zenith = jnp.sin(zenith) + cos_zenith = jnp.cos(zenith) + + rotation = euler_rotation(delta_x) + + theta = jnp.acos(rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth + rotation[2][2] * cos_zenith + ) + phi = jnp.fmod( + jnp.atan2( + rotation[1][0] * sin_zenith * cos_azimuth + + rotation[1][1] * sin_zenith * sin_azimuth + + rotation[1][2] * cos_zenith, + rotation[0][0] * sin_zenith * cos_azimuth + + rotation[0][1] * sin_zenith * sin_azimuth + + rotation[0][2] * cos_zenith) + + 2 * jnp.pi, + (2 * jnp.pi) + ) + return theta, phi + + +@jit +def azimuth_zenith_to_ra_dec(azimuth: Float, zenith: Float, geocent_time: Float, ifos: list) -> tuple[Float, Float]: + """ + Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. + + Parameters + ---------- + azimuth : Float + Azimuthal angle. + zenith : Float + Zenith angle. + + Copied and modified from bilby/gw/utils.py + + Returns + ------- + ra : Float + Right ascension. + dec : Float + Declination. + """ + delta_x = ifos[0].vertex - ifos[1].vertex + theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + gmst = greenwich_mean_sidereal_time(geocent_time) + ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) + ra = ra % (2 * jnp.pi) + return ra, dec + + def log_i0(x): """ A numerically stable method to evaluate log of From eba2349b3696c3463532f761695bbd2f6f61c6d9 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 3 Jul 2024 02:51:37 +0800 Subject: [PATCH 2/8] Finish EarthFrame prior --- src/jimgw/prior.py | 24 +++++++---- src/jimgw/single_event/runManager.py | 1 + src/jimgw/single_event/utils.py | 60 ++++++++++++++-------------- 3 files changed, 47 insertions(+), 38 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 545bee23..e047970e 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -6,6 +6,9 @@ from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped from beartype import beartype as typechecker +from jimgw.single_event.utils import azimuth_zenith_to_ra_dec +from jimgw.single_event.detector import Detector, detector_preset +from astropy.time import Time class Prior(Distribution): @@ -400,24 +403,29 @@ class EarthFrame(Prior): def __repr__(self): return f"EarthFrame(naming={self.naming})" - def __init__(self, naming: str, **kwargs): + def __init__(self, naming: str, gps: Float, ifos: list, **kwargs): self.naming = ["azimuth", "zenith"] + if len(ifos) < 2: + return ValueError("At least two detectors are needed to define the Earth frame") + elif isinstance(ifos[0], str): + self.ifos = [detector_preset[ifo] for ifo in ifos[:2]] + elif isinstance(ifos[0], Detector): + self.ifos = ifos[:2] + else: + return ValueError("ifos should be a list of detector names or Detector objects") + self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad self.transforms = { "azimuth": ( - "ra", - lambda params: + "ra", lambda params: azimuth_zenith_to_ra_dec(params["azimuth"], params["zenith"], gmst=self.gmst, ifos=ifos)[0] ), "zenith": ( - "dec", - lambda params: + "dec", lambda params: azimuth_zenith_to_ra_dec(params["azimuth"], params["zenith"], gmst=self.gmst, ifos=ifos)[1] ), } def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> dict[str, Float[Array, " n_samples"]]: rng_keys = jax.random.split(rng_key, 2) - zenith = jnp.arccos( - jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) - ) + zenith = jnp.arccos(jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0)) azimuth = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) return self.add_name(jnp.stack([azimuth, zenith], axis=1).T) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index d520f561..56eb9925 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -58,6 +58,7 @@ def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl): ) }, ), + "EarthFrame": prior.EarthFrame, } diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 8a0ee633..2e094a97 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -3,6 +3,7 @@ from jax.scipy.integrate import trapezoid from jax.scipy.special import i0e from jaxtyping import Array, Float +from jimgw.single_event.detector import Detector @jit @@ -144,32 +145,6 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa return theta, phi -@jit -def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, Float]: - """ - Transforming the polar angle and azimuthal angle to right ascension and declination. - - Parameters - ---------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. - gmst : Float - Greenwich mean sidereal time. - - Returns - ------- - ra : Float - Right ascension. - dec : Float - Declination. - """ - ra = phi + gmst - dec = jnp.pi / 2 - theta - return ra, dec - - @jit def euler_rotation(delta_x: tuple[Float, Float, Float]): """ @@ -255,9 +230,35 @@ def zenith_azimuth_to_theta_phi(zenith: Float, azimuth: Float, delta_x: tuple[Fl @jit -def azimuth_zenith_to_ra_dec(azimuth: Float, zenith: Float, geocent_time: Float, ifos: list) -> tuple[Float, Float]: +def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the polar angle and azimuthal angle to right ascension and declination. + + Parameters + ---------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + ra : Float + Right ascension. + dec : Float + Declination. + """ + ra = phi + gmst + dec = jnp.pi / 2 - theta + return ra, dec + + +@jit +def azimuth_zenith_to_ra_dec(azimuth: Float, zenith: Float, gmst: Float, ifos: list[Detector]) -> tuple[Float, Float]: """ - Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. + Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. Parameters ---------- @@ -277,12 +278,11 @@ def azimuth_zenith_to_ra_dec(azimuth: Float, zenith: Float, geocent_time: Float, """ delta_x = ifos[0].vertex - ifos[1].vertex theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) - gmst = greenwich_mean_sidereal_time(geocent_time) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) ra = ra % (2 * jnp.pi) return ra, dec - +@jit def log_i0(x): """ A numerically stable method to evaluate log of From e8e645cacbefef46d0ee803ff5dd70eed7053e60 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 4 Jul 2024 09:55:42 +0800 Subject: [PATCH 3/8] Change to GroundBased2G --- src/jimgw/prior.py | 4 ++-- src/jimgw/single_event/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 02ed0896..bb43b9b2 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -7,7 +7,7 @@ from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped from beartype import beartype as typechecker from jimgw.single_event.utils import azimuth_zenith_to_ra_dec -from jimgw.single_event.detector import Detector, detector_preset +from jimgw.single_event.detector import GroundBased2G, detector_preset from astropy.time import Time @@ -409,7 +409,7 @@ def __init__(self, naming: str, gps: Float, ifos: list, **kwargs): return ValueError("At least two detectors are needed to define the Earth frame") elif isinstance(ifos[0], str): self.ifos = [detector_preset[ifo] for ifo in ifos[:2]] - elif isinstance(ifos[0], Detector): + elif isinstance(ifos[0], GroundBased2G): self.ifos = ifos[:2] else: return ValueError("ifos should be a list of detector names or Detector objects") diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 2e094a97..0130324e 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -3,7 +3,7 @@ from jax.scipy.integrate import trapezoid from jax.scipy.special import i0e from jaxtyping import Array, Float -from jimgw.single_event.detector import Detector +from jimgw.single_event.detector import GroundBased2G @jit @@ -256,7 +256,7 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F @jit -def azimuth_zenith_to_ra_dec(azimuth: Float, zenith: Float, gmst: Float, ifos: list[Detector]) -> tuple[Float, Float]: +def azimuth_zenith_to_ra_dec(azimuth: Float, zenith: Float, gmst: Float, ifos: list[GroundBased2G]) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. From fba5f791de0cbdd61898c1aced8a82c0463a2706 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 4 Jul 2024 10:37:43 +0800 Subject: [PATCH 4/8] Update prior.py --- src/jimgw/prior.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index bb43b9b2..b95eb7bc 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -224,6 +224,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) +@jaxtyped(typechecker=typechecker) class Sphere(Prior): """ A prior on a sphere represented by Cartesian coordinates. @@ -395,6 +396,8 @@ def log_prob(self, x: dict[str, Float]) -> Float: ) return log_p + +@jaxtyped(typechecker=typechecker) class EarthFrame(Prior): """ Prior distribution for sky location in Earth frame. From 137ecb40a50a96125a899546f72509914ab0adfd Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Thu, 4 Jul 2024 10:41:22 +0800 Subject: [PATCH 5/8] Update prior.py --- src/jimgw/prior.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index b95eb7bc..176bb7be 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -224,7 +224,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) -@jaxtyped(typechecker=typechecker) class Sphere(Prior): """ A prior on a sphere represented by Cartesian coordinates. From 1325620261b6ba401d6f7a8805059941e5a26130 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 8 Jul 2024 10:25:41 +0800 Subject: [PATCH 6/8] Debug --- src/jimgw/prior.py | 8 +++++--- src/jimgw/single_event/utils.py | 30 ++++++++++++------------------ 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 176bb7be..9a101f25 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -6,7 +6,7 @@ from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped from beartype import beartype as typechecker -from jimgw.single_event.utils import azimuth_zenith_to_ra_dec +from jimgw.single_event.utils import zenith_azimuth_to_ra_dec from jimgw.single_event.detector import GroundBased2G, detector_preset from astropy.time import Time @@ -416,12 +416,14 @@ def __init__(self, naming: str, gps: Float, ifos: list, **kwargs): else: return ValueError("ifos should be a list of detector names or Detector objects") self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad + self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex + self.transforms = { "azimuth": ( - "ra", lambda params: azimuth_zenith_to_ra_dec(params["azimuth"], params["zenith"], gmst=self.gmst, ifos=ifos)[0] + "ra", lambda params: zenith_azimuth_to_ra_dec(params["zenith"], params["azimuth"], gmst=self.gmst, delta_x=self.delta_x)[0] ), "zenith": ( - "dec", lambda params: azimuth_zenith_to_ra_dec(params["azimuth"], params["zenith"], gmst=self.gmst, ifos=ifos)[1] + "dec", lambda params: zenith_azimuth_to_ra_dec(params["zenith"], params["azimuth"], gmst=self.gmst, delta_x=self.delta_x)[1] ), } diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 0130324e..a7ec5727 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -160,7 +160,7 @@ def euler_rotation(delta_x: tuple[Float, Float, Float]): cos_beta = delta_x[2] / norm sin_beta = jnp.power(1 - cos_beta**2, 0.5) - alpha = jnp.atan2(- delta_x[1] * cos_beta, delta_x[0]) + alpha = jnp.atan2(-delta_x[1] * cos_beta, delta_x[0]) gamma = jnp.atan2(delta_x[1], delta_x[0]) cos_alpha = jnp.cos(alpha) @@ -168,17 +168,9 @@ def euler_rotation(delta_x: tuple[Float, Float, Float]): cos_gamma = jnp.cos(gamma) sin_gamma = jnp.sin(gamma) - rotation = jnp.empty((3, 3)) - - rotation[0][0] = cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma - rotation[1][0] = cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma - rotation[2][0] = -cos_alpha * sin_beta - rotation[0][1] = -sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma - rotation[1][1] = -sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma - rotation[2][1] = sin_alpha * sin_beta - rotation[0][2] = sin_beta * cos_gamma - rotation[1][2] = sin_beta * sin_gamma - rotation[2][2] = cos_beta + rotation = jnp.array([[cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma, -sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma, sin_beta * cos_gamma], + [cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma, -sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma, sin_beta * sin_gamma], + [-cos_alpha * sin_beta, sin_alpha * sin_beta, cos_beta]]) return rotation @@ -213,8 +205,7 @@ def zenith_azimuth_to_theta_phi(zenith: Float, azimuth: Float, delta_x: tuple[Fl rotation = euler_rotation(delta_x) - theta = jnp.acos(rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth + rotation[2][2] * cos_zenith - ) + theta = jnp.acos(rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth + rotation[2][2] * cos_zenith) phi = jnp.fmod( jnp.atan2( rotation[1][0] * sin_zenith * cos_azimuth @@ -256,16 +247,20 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F @jit -def azimuth_zenith_to_ra_dec(azimuth: Float, zenith: Float, gmst: Float, ifos: list[GroundBased2G]) -> tuple[Float, Float]: +def zenith_azimuth_to_ra_dec(zenith: Float, azimuth: Float, gmst: Float, delta_x: tuple[Float, Float, Float]) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. Parameters ---------- - azimuth : Float - Azimuthal angle. zenith : Float Zenith angle. + azimuth : Float + Azimuthal angle. + gmst : Float + Greenwich mean sidereal time. + delta_x : Float + The vector pointing from the first detector to the second detector. Copied and modified from bilby/gw/utils.py @@ -276,7 +271,6 @@ def azimuth_zenith_to_ra_dec(azimuth: Float, zenith: Float, gmst: Float, ifos: l dec : Float Declination. """ - delta_x = ifos[0].vertex - ifos[1].vertex theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) ra = ra % (2 * jnp.pi) From abb7d9c2de395f7a09c5122258e0a1a5737165d3 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 10 Jul 2024 09:15:04 +0800 Subject: [PATCH 7/8] Reformat --- src/jimgw/prior.py | 44 ++++++++++++++++++------ src/jimgw/single_event/utils.py | 61 +++++++++++++++++++++++---------- 2 files changed, 76 insertions(+), 29 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 9a101f25..e5675c8f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -394,7 +394,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax), ) return log_p - + @jaxtyped(typechecker=typechecker) class EarthFrame(Prior): @@ -408,34 +408,56 @@ def __repr__(self): def __init__(self, naming: str, gps: Float, ifos: list, **kwargs): self.naming = ["azimuth", "zenith"] if len(ifos) < 2: - return ValueError("At least two detectors are needed to define the Earth frame") + return ValueError( + "At least two detectors are needed to define the Earth frame" + ) elif isinstance(ifos[0], str): self.ifos = [detector_preset[ifo] for ifo in ifos[:2]] elif isinstance(ifos[0], GroundBased2G): self.ifos = ifos[:2] else: - return ValueError("ifos should be a list of detector names or Detector objects") + return ValueError( + "ifos should be a list of detector names or Detector objects" + ) self.gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad self.delta_x = self.ifos[1].vertex - self.ifos[0].vertex - + self.transforms = { "azimuth": ( - "ra", lambda params: zenith_azimuth_to_ra_dec(params["zenith"], params["azimuth"], gmst=self.gmst, delta_x=self.delta_x)[0] + "ra", + lambda params: zenith_azimuth_to_ra_dec( + params["zenith"], + params["azimuth"], + gmst=self.gmst, + delta_x=self.delta_x, + )[0], ), "zenith": ( - "dec", lambda params: zenith_azimuth_to_ra_dec(params["zenith"], params["azimuth"], gmst=self.gmst, delta_x=self.delta_x)[1] + "dec", + lambda params: zenith_azimuth_to_ra_dec( + params["zenith"], + params["azimuth"], + gmst=self.gmst, + delta_x=self.delta_x, + )[1], ), } - def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> dict[str, Float[Array, " n_samples"]]: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: rng_keys = jax.random.split(rng_key, 2) - zenith = jnp.arccos(jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0)) - azimuth = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi) + zenith = jnp.arccos( + jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) + ) + azimuth = jax.random.uniform( + rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi + ) return self.add_name(jnp.stack([azimuth, zenith], axis=1).T) def log_prob(self, x: dict[str, Float]) -> Float: - zenith = x['zenith'] - azimuth = x['azimuth'] + zenith = x["zenith"] + azimuth = x["azimuth"] output = jnp.where( (azimuth > 2 * jnp.pi) | (azimuth < 0) | (zenith > jnp.pi) | (zenith < 0), jnp.zeros_like(0) - jnp.inf, diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index a7ec5727..7d5e2b8b 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -118,6 +118,7 @@ def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: m2 = m1 * q return m1, m2 + @jit def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: """ @@ -156,7 +157,9 @@ def euler_rotation(delta_x: tuple[Float, Float, Float]): Copied and modified from bilby-cython/geometry.pyx """ - norm = jnp.power(delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5) + norm = jnp.power( + delta_x[0] * delta_x[0] + delta_x[1] * delta_x[1] + delta_x[2] * delta_x[2], 0.5 + ) cos_beta = delta_x[2] / norm sin_beta = jnp.power(1 - cos_beta**2, 0.5) @@ -168,15 +171,29 @@ def euler_rotation(delta_x: tuple[Float, Float, Float]): cos_gamma = jnp.cos(gamma) sin_gamma = jnp.sin(gamma) - rotation = jnp.array([[cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma, -sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma, sin_beta * cos_gamma], - [cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma, -sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma, sin_beta * sin_gamma], - [-cos_alpha * sin_beta, sin_alpha * sin_beta, cos_beta]]) + rotation = jnp.array( + [ + [ + cos_alpha * cos_beta * cos_gamma - sin_alpha * sin_gamma, + -sin_alpha * cos_beta * cos_gamma - cos_alpha * sin_gamma, + sin_beta * cos_gamma, + ], + [ + cos_alpha * cos_beta * sin_gamma + sin_alpha * cos_gamma, + -sin_alpha * cos_beta * sin_gamma + cos_alpha * cos_gamma, + sin_beta * sin_gamma, + ], + [-cos_alpha * sin_beta, sin_alpha * sin_beta, cos_beta], + ] + ) return rotation @jit -def zenith_azimuth_to_theta_phi(zenith: Float, azimuth: Float, delta_x: tuple[Float, Float, Float]) -> tuple[Float, Float]: +def zenith_azimuth_to_theta_phi( + zenith: Float, azimuth: Float, delta_x: tuple[Float, Float, Float] +) -> tuple[Float, Float]: """ Transforming the azimuthal angle and zenith angle in Earth frame to the polar angle and azimuthal angle in sky frame. @@ -205,18 +222,23 @@ def zenith_azimuth_to_theta_phi(zenith: Float, azimuth: Float, delta_x: tuple[Fl rotation = euler_rotation(delta_x) - theta = jnp.acos(rotation[2][0] * sin_zenith * cos_azimuth + rotation[2][1] * sin_zenith * sin_azimuth + rotation[2][2] * cos_zenith) + theta = jnp.acos( + rotation[2][0] * sin_zenith * cos_azimuth + + rotation[2][1] * sin_zenith * sin_azimuth + + rotation[2][2] * cos_zenith + ) phi = jnp.fmod( - jnp.atan2( - rotation[1][0] * sin_zenith * cos_azimuth - + rotation[1][1] * sin_zenith * sin_azimuth - + rotation[1][2] * cos_zenith, - rotation[0][0] * sin_zenith * cos_azimuth - + rotation[0][1] * sin_zenith * sin_azimuth - + rotation[0][2] * cos_zenith) - + 2 * jnp.pi, - (2 * jnp.pi) - ) + jnp.atan2( + rotation[1][0] * sin_zenith * cos_azimuth + + rotation[1][1] * sin_zenith * sin_azimuth + + rotation[1][2] * cos_zenith, + rotation[0][0] * sin_zenith * cos_azimuth + + rotation[0][1] * sin_zenith * sin_azimuth + + rotation[0][2] * cos_zenith, + ) + + 2 * jnp.pi, + (2 * jnp.pi), + ) return theta, phi @@ -247,9 +269,11 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F @jit -def zenith_azimuth_to_ra_dec(zenith: Float, azimuth: Float, gmst: Float, delta_x: tuple[Float, Float, Float]) -> tuple[Float, Float]: +def zenith_azimuth_to_ra_dec( + zenith: Float, azimuth: Float, gmst: Float, delta_x: tuple[Float, Float, Float] +) -> tuple[Float, Float]: """ - Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. + Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. Parameters ---------- @@ -276,6 +300,7 @@ def zenith_azimuth_to_ra_dec(zenith: Float, azimuth: Float, gmst: Float, delta_x ra = ra % (2 * jnp.pi) return ra, dec + @jit def log_i0(x): """ From edf316a1b365746a449bb4d6a48901fd6ac29f75 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 10 Jul 2024 09:20:31 +0800 Subject: [PATCH 8/8] Reformat with Ruff --- src/jimgw/single_event/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 7d5e2b8b..72ef5e36 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -3,7 +3,6 @@ from jax.scipy.integrate import trapezoid from jax.scipy.special import i0e from jaxtyping import Array, Float -from jimgw.single_event.detector import GroundBased2G @jit