Skip to content

Commit

Permalink
Merge pull request #54 from DifferentiableUniverseInitiative/u/EiffL/…
Browse files Browse the repository at this point in the history
…spline_integration

Updates background to use new cubic splines
  • Loading branch information
EiffL authored Jul 18, 2020
2 parents 850fae8 + 1f6f413 commit d86bf6b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 39 deletions.
18 changes: 11 additions & 7 deletions jax_cosmo/angular_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import jax_cosmo.power as power
import jax_cosmo.transfer as tklib
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
from jax_cosmo.utils import a2z
from jax_cosmo.utils import z2a

Expand Down Expand Up @@ -54,7 +55,7 @@ def find_index(a, b):


def angular_cl(
cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit
cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.linear
):
"""
Computes angular Cls for the provided probes
Expand Down Expand Up @@ -93,13 +94,16 @@ def combine_kernels(inds):

# Now kernels has shape [ncls, na]
kernels = lax.map(combine_kernels, cl_index)

result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi ** 2, 1.0)
return result

# We transpose the result just to make sure that na is first
return result.T

return simps(integrand, z2a(zmax), 1.0, 512) / const.c ** 2
atab = np.linspace(z2a(zmax), 1.0, 64)
eval_integral = vmap(
lambda x: np.squeeze(
InterpolatedUnivariateSpline(atab, x).integral(z2a(zmax), 1.0)
)
)
return eval_integral(integrand(atab)) / const.c ** 2

return cl(ell)

Expand Down Expand Up @@ -161,7 +165,7 @@ def gaussian_cl_covariance_and_mean(
ell,
probes,
transfer_fn=tklib.Eisenstein_Hu,
nonlinear_fn=power.halofit,
nonlinear_fn=power.linear,
f_sky=0.25,
):
"""
Expand Down
31 changes: 18 additions & 13 deletions jax_cosmo/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as np

import jax_cosmo.constants as const
from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
from jax_cosmo.scipy.ode import odeint

__all__ = [
Expand Down Expand Up @@ -202,7 +202,7 @@ def Omega_de_a(cosmo, a):
return cosmo.Omega_de * np.power(a, f_de(cosmo, a)) / Esqr(cosmo, a)


def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256):
def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=64):
r"""Radial comoving distance in [Mpc/h] for a given scale factor.
Parameters
Expand Down Expand Up @@ -235,17 +235,19 @@ def dchioverdlna(y, x):
return dchioverda(cosmo, xa) * xa

chitab = odeint(dchioverdlna, 0.0, np.log(atab))
# np.clip(- 3000*np.log(atab), 0, 10000)#odeint(dchioverdlna, 0., np.log(atab), cosmo)
chitab = chitab[-1] - chitab

cache = {"a": atab, "chi": chitab}
cache = {
"a2chi": InterpolatedUnivariateSpline(atab, chitab),
"chi2a": InterpolatedUnivariateSpline(chitab, atab),
}
cosmo._workspace["background.radial_comoving_distance"] = cache
else:
cache = cosmo._workspace["background.radial_comoving_distance"]

a = np.atleast_1d(a)
# Return the results as an interpolation of the table
return np.clip(interp(a, cache["a"], cache["chi"]), 0.0)
return np.clip(cache["a2chi"](a), 0.0)


def a_of_chi(cosmo, chi):
Expand All @@ -270,7 +272,7 @@ def a_of_chi(cosmo, chi):
radial_comoving_distance(cosmo, 1.0)
cache = cosmo._workspace["background.radial_comoving_distance"]
chi = np.atleast_1d(chi)
return interp(chi, cache["chi"], cache["a"])
return cache["chi2a"](chi)


def dchioverda(cosmo, a):
Expand Down Expand Up @@ -437,7 +439,7 @@ def growth_rate(cosmo, a):
return _growth_rate_ODE(cosmo, a)


def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=64):
""" Compute linear growth factor D(a) at a given scale factor,
normalised such that D(a=1) = 1.
Expand Down Expand Up @@ -478,11 +480,14 @@ def D_derivs(y, x):
# To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da
ftab = y[:, 1] / y1[-1] * atab / gtab

cache = {"a": atab, "g": gtab, "f": ftab}
cache = {
"g": InterpolatedUnivariateSpline(atab, gtab),
"f": InterpolatedUnivariateSpline(atab, ftab),
}
cosmo._workspace["background.growth_factor"] = cache
else:
cache = cosmo._workspace["background.growth_factor"]
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
return np.clip(cache["g"](a), 0.0, 1.0)


def _growth_rate_ODE(cosmo, a):
Expand All @@ -506,10 +511,10 @@ def _growth_rate_ODE(cosmo, a):
if not "background.growth_factor" in cosmo._workspace.keys():
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
cache = cosmo._workspace["background.growth_factor"]
return interp(a, cache["a"], cache["f"])
return cache["f"](a)


def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=64):
r""" Computes growth factor by integrating the growth rate provided by the
\gamma parametrization. Normalized such that D( a=1) =1
Expand Down Expand Up @@ -538,11 +543,11 @@ def integrand(y, loga):

gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
gtab = gtab / gtab[-1] # Normalize to a=1.
cache = {"a": atab, "g": gtab}
cache = {"g": InterpolatedUnivariateSpline(atab, gtab)}
cosmo._workspace["background.growth_factor"] = cache
else:
cache = cosmo._workspace["background.growth_factor"]
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
return np.clip(cache["g"](a), 0.0, 1.0)


def _growth_rate_gamma(cosmo, a):
Expand Down
16 changes: 8 additions & 8 deletions tests/test_angular_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_lensing_cl():
n_s=0.96,
Neff=0,
transfer_function="eisenstein_hu",
matter_power_spectrum="halofit",
matter_power_spectrum="linear",
)

cosmo_jax = Cosmology(
Expand All @@ -43,13 +43,13 @@ def test_lensing_cl():
tracer_jax = probes.WeakLensing([nz])

# Get an ell range for the cls
ell = np.logspace(0.1, 4)
ell = np.logspace(1, 4)

# Compute the cls
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell)
cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax])

assert_allclose(cl_ccl, cl_jax[0], rtol=1.0e-2)
assert_allclose(cl_ccl, cl_jax[0], rtol=1e-2)


def test_lensing_cl_IA():
Expand All @@ -62,7 +62,7 @@ def test_lensing_cl_IA():
n_s=0.96,
Neff=0,
transfer_function="eisenstein_hu",
matter_power_spectrum="halofit",
matter_power_spectrum="linear",
)

cosmo_jax = Cosmology(
Expand Down Expand Up @@ -90,13 +90,13 @@ def test_lensing_cl_IA():
tracer_jax = probes.WeakLensing([nz], bias)

# Get an ell range for the cls
ell = np.logspace(0.1, 4)
ell = np.logspace(1, 4)

# Compute the cls
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell)
cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax])

assert_allclose(cl_ccl, cl_jax[0], rtol=1e-2)
assert_allclose(cl_ccl, cl_jax[0], rtol=5e-3)


def test_clustering_cl():
Expand All @@ -109,7 +109,7 @@ def test_clustering_cl():
n_s=0.96,
Neff=0,
transfer_function="eisenstein_hu",
matter_power_spectrum="halofit",
matter_power_spectrum="linear",
)

cosmo_jax = Cosmology(
Expand All @@ -136,7 +136,7 @@ def test_clustering_cl():
tracer_jax = probes.NumberCounts([nz], bias)

# Get an ell range for the cls
ell = np.logspace(0.1, 4)
ell = np.logspace(1, 4)

# Compute the cls
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell)
Expand Down
22 changes: 11 additions & 11 deletions tests/test_background.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def test_distances_flat():

chi_ccl = ccl.comoving_radial_distance(cosmo_ccl, a)
chi_jax = bkgrd.radial_comoving_distance(cosmo_jax, a) / cosmo_jax.h
assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2)
assert_allclose(chi_ccl, chi_jax, rtol=1e-3)

chi_ccl = ccl.comoving_angular_distance(cosmo_ccl, a)
chi_jax = bkgrd.transverse_comoving_distance(cosmo_jax, a) / cosmo_jax.h
assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2)
assert_allclose(chi_ccl, chi_jax, rtol=1e-3)

chi_ccl = ccl.angular_diameter_distance(cosmo_ccl, a)
chi_jax = bkgrd.angular_diameter_distance(cosmo_jax, a) / cosmo_jax.h
assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2)
assert_allclose(chi_ccl, chi_jax, rtol=1e-3)


def test_growth():
Expand Down Expand Up @@ -72,12 +72,12 @@ def test_growth():
)

# Test array of scale factors
a = np.linspace(0.01, 1.0)
a = np.linspace(0.1, 1.0)

gccl = ccl.growth_factor(cosmo_ccl, a)
gjax = bkgrd.growth_factor(cosmo_jax, a)

assert_allclose(gccl, gjax, rtol=1e-2)
assert_allclose(gccl, gjax, rtol=1e-3)


def test_growth_rate():
Expand Down Expand Up @@ -105,12 +105,12 @@ def test_growth_rate():
)

# Test array of scale factors
a = np.linspace(0.01, 1.0)
a = np.linspace(0.1, 1.0)

fccl = ccl.growth_rate(cosmo_ccl, a)
fjax = bkgrd.growth_rate(cosmo_jax, a)

assert_allclose(fccl, fjax, rtol=1e-2)
assert_allclose(fccl, fjax, rtol=1e-3)


def test_growth_rate_gamma():
Expand Down Expand Up @@ -140,12 +140,12 @@ def test_growth_rate_gamma():
)

# Test array of scale factors
a = np.linspace(0.01, 1.0)
a = np.linspace(0.1, 1.0)

fccl = ccl.growth_rate(cosmo_ccl, a)
fjax = bkgrd.growth_rate(cosmo_jax, a)

assert_allclose(fccl, fjax, rtol=1e-2)
assert_allclose(fccl, fjax, rtol=5e-3)


def test_growth_gamma():
Expand Down Expand Up @@ -174,9 +174,9 @@ def test_growth_gamma():
)

# Test array of scale factors
a = np.linspace(0.01, 1.0)
a = np.linspace(0.1, 1.0)

gccl = ccl.growth_factor(cosmo_ccl, a)
gjax = bkgrd.growth_factor(cosmo_jax, a)

assert_allclose(gccl, gjax, rtol=1e-2)
assert_allclose(gccl, gjax, rtol=1e-3)

0 comments on commit d86bf6b

Please sign in to comment.