From ae3cefdccb0859c80d8ec19fa48a178ecd9977f1 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 19 Jan 2022 16:58:11 -0500 Subject: [PATCH 1/3] Change Cosmology.k to int type consistently --- jax_cosmo/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 217d13c..4c2ae05 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -169,11 +169,11 @@ def Omega_k(self): @property def k(self): if self.Omega > 1.0: # Closed universe - k = 1.0 + k = 1 elif self.Omega == 1.0: # Flat universe k = 0 elif self.Omega < 1.0: # Open Universe - k = -1.0 + k = -1 return k @property From 3e45de09950806b440bab58cef853376504fb162 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 19 Jan 2022 19:51:34 -0500 Subject: [PATCH 2/3] Make transverse comoving distance jitable --- jax_cosmo/background.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index 6cec947..98f7815 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -1,5 +1,6 @@ # This module implements various functions for the background COSMOLOGY import jax.numpy as np +from jax import lax import jax_cosmo.constants as const from jax_cosmo.scipy.interpolate import interp @@ -328,14 +329,23 @@ def transverse_comoving_distance(cosmo, a): \end{matrix} \right. """ - chi = radial_comoving_distance(cosmo, a) - if cosmo.k < 0: # Open universe + index = cosmo.k + 1 + + def open_universe(chi): return const.rh / cosmo.sqrtk * np.sinh(cosmo.sqrtk * chi / const.rh) - elif cosmo.k > 0: # Closed Universe - return const.rh / cosmo.sqrtk * np.sin(cosmo.sqrtk * chi / const.rh) - else: + + def flat_universe(chi): return chi + def close_universe(chi): + return const.rh / cosmo.sqrtk * np.sin(cosmo.sqrtk * chi / const.rh) + + branches = (open_universe, flat_universe, close_universe) + + chi = radial_comoving_distance(cosmo, a) + + return lax.switch(cosmo.k + 1, branches, chi) + def angular_diameter_distance(cosmo, a): r"""Angular diameter distance in [Mpc/h] for a given scale factor. From cf8500c2f0e174c5b132d2c53400150194dffb70 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 19 Jan 2022 20:07:18 -0500 Subject: [PATCH 3/3] Make Cosmology.k jitable --- jax_cosmo/core.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 4c2ae05..4ff49d6 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -168,13 +168,7 @@ def Omega_k(self): @property def k(self): - if self.Omega > 1.0: # Closed universe - k = 1 - elif self.Omega == 1.0: # Flat universe - k = 0 - elif self.Omega < 1.0: # Open Universe - k = -1 - return k + return -np.sign(self._Omega_k).astype(np.int8) @property def sqrtk(self):