Skip to content

Commit

Permalink
Merge pull request #84 from eelregit/eelregit_transverse_distance_patch
Browse files Browse the repository at this point in the history
Make k and transverse distance jitable
  • Loading branch information
EiffL authored Jan 20, 2022
2 parents 96634cc + cf8500c commit d5c7464
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
20 changes: 15 additions & 5 deletions jax_cosmo/background.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# This module implements various functions for the background COSMOLOGY
import jax.numpy as np
from jax import lax

import jax_cosmo.constants as const
from jax_cosmo.scipy.interpolate import interp
Expand Down Expand Up @@ -328,14 +329,23 @@ def transverse_comoving_distance(cosmo, a):
\end{matrix}
\right.
"""
chi = radial_comoving_distance(cosmo, a)
if cosmo.k < 0: # Open universe
index = cosmo.k + 1

def open_universe(chi):
return const.rh / cosmo.sqrtk * np.sinh(cosmo.sqrtk * chi / const.rh)
elif cosmo.k > 0: # Closed Universe
return const.rh / cosmo.sqrtk * np.sin(cosmo.sqrtk * chi / const.rh)
else:

def flat_universe(chi):
return chi

def close_universe(chi):
return const.rh / cosmo.sqrtk * np.sin(cosmo.sqrtk * chi / const.rh)

branches = (open_universe, flat_universe, close_universe)

chi = radial_comoving_distance(cosmo, a)

return lax.switch(cosmo.k + 1, branches, chi)


def angular_diameter_distance(cosmo, a):
r"""Angular diameter distance in [Mpc/h] for a given scale factor.
Expand Down
8 changes: 1 addition & 7 deletions jax_cosmo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,7 @@ def Omega_k(self):

@property
def k(self):
if self.Omega > 1.0: # Closed universe
k = 1.0
elif self.Omega == 1.0: # Flat universe
k = 0
elif self.Omega < 1.0: # Open Universe
k = -1.0
return k
return -np.sign(self._Omega_k).astype(np.int8)

@property
def sqrtk(self):
Expand Down

0 comments on commit d5c7464

Please sign in to comment.