diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index 6cec947..98f7815 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -1,5 +1,6 @@ # This module implements various functions for the background COSMOLOGY import jax.numpy as np +from jax import lax import jax_cosmo.constants as const from jax_cosmo.scipy.interpolate import interp @@ -328,14 +329,23 @@ def transverse_comoving_distance(cosmo, a): \end{matrix} \right. """ - chi = radial_comoving_distance(cosmo, a) - if cosmo.k < 0: # Open universe + index = cosmo.k + 1 + + def open_universe(chi): return const.rh / cosmo.sqrtk * np.sinh(cosmo.sqrtk * chi / const.rh) - elif cosmo.k > 0: # Closed Universe - return const.rh / cosmo.sqrtk * np.sin(cosmo.sqrtk * chi / const.rh) - else: + + def flat_universe(chi): return chi + def close_universe(chi): + return const.rh / cosmo.sqrtk * np.sin(cosmo.sqrtk * chi / const.rh) + + branches = (open_universe, flat_universe, close_universe) + + chi = radial_comoving_distance(cosmo, a) + + return lax.switch(cosmo.k + 1, branches, chi) + def angular_diameter_distance(cosmo, a): r"""Angular diameter distance in [Mpc/h] for a given scale factor. diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 217d13c..4ff49d6 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -168,13 +168,7 @@ def Omega_k(self): @property def k(self): - if self.Omega > 1.0: # Closed universe - k = 1.0 - elif self.Omega == 1.0: # Flat universe - k = 0 - elif self.Omega < 1.0: # Open Universe - k = -1.0 - return k + return -np.sign(self._Omega_k).astype(np.int8) @property def sqrtk(self):