diff --git a/jax_cosmo/angular_cl.py b/jax_cosmo/angular_cl.py index 63db4dd..62a5721 100644 --- a/jax_cosmo/angular_cl.py +++ b/jax_cosmo/angular_cl.py @@ -15,6 +15,7 @@ import jax_cosmo.power as power import jax_cosmo.transfer as tklib from jax_cosmo.scipy.integrate import simps +from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline from jax_cosmo.utils import a2z from jax_cosmo.utils import z2a @@ -54,7 +55,7 @@ def find_index(a, b): def angular_cl( - cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit + cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.linear ): """ Computes angular Cls for the provided probes @@ -93,13 +94,16 @@ def combine_kernels(inds): # Now kernels has shape [ncls, na] kernels = lax.map(combine_kernels, cl_index) - result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi ** 2, 1.0) + return result - # We transpose the result just to make sure that na is first - return result.T - - return simps(integrand, z2a(zmax), 1.0, 512) / const.c ** 2 + atab = np.linspace(z2a(zmax), 1.0, 64) + eval_integral = vmap( + lambda x: np.squeeze( + InterpolatedUnivariateSpline(atab, x).integral(z2a(zmax), 1.0) + ) + ) + return eval_integral(integrand(atab)) / const.c ** 2 return cl(ell) @@ -161,7 +165,7 @@ def gaussian_cl_covariance_and_mean( ell, probes, transfer_fn=tklib.Eisenstein_Hu, - nonlinear_fn=power.halofit, + nonlinear_fn=power.linear, f_sky=0.25, ): """ diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index aa0481c..7c15f7b 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -6,7 +6,7 @@ import jax.numpy as np import jax_cosmo.constants as const -from jax_cosmo.scipy.interpolate import interp +from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline from jax_cosmo.scipy.ode import odeint __all__ = [ @@ -202,7 +202,7 @@ def Omega_de_a(cosmo, a): return cosmo.Omega_de * np.power(a, f_de(cosmo, a)) / Esqr(cosmo, a) -def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256): +def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=64): r"""Radial comoving distance in [Mpc/h] for a given scale factor. Parameters @@ -235,17 +235,19 @@ def dchioverdlna(y, x): return dchioverda(cosmo, xa) * xa chitab = odeint(dchioverdlna, 0.0, np.log(atab)) - # np.clip(- 3000*np.log(atab), 0, 10000)#odeint(dchioverdlna, 0., np.log(atab), cosmo) chitab = chitab[-1] - chitab - cache = {"a": atab, "chi": chitab} + cache = { + "a2chi": InterpolatedUnivariateSpline(atab, chitab), + "chi2a": InterpolatedUnivariateSpline(chitab, atab), + } cosmo._workspace["background.radial_comoving_distance"] = cache else: cache = cosmo._workspace["background.radial_comoving_distance"] a = np.atleast_1d(a) # Return the results as an interpolation of the table - return np.clip(interp(a, cache["a"], cache["chi"]), 0.0) + return np.clip(cache["a2chi"](a), 0.0) def a_of_chi(cosmo, chi): @@ -270,7 +272,7 @@ def a_of_chi(cosmo, chi): radial_comoving_distance(cosmo, 1.0) cache = cosmo._workspace["background.radial_comoving_distance"] chi = np.atleast_1d(chi) - return interp(chi, cache["chi"], cache["a"]) + return cache["chi2a"](chi) def dchioverda(cosmo, a): @@ -437,7 +439,7 @@ def growth_rate(cosmo, a): return _growth_rate_ODE(cosmo, a) -def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): +def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=64): """ Compute linear growth factor D(a) at a given scale factor, normalised such that D(a=1) = 1. @@ -478,11 +480,14 @@ def D_derivs(y, x): # To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da ftab = y[:, 1] / y1[-1] * atab / gtab - cache = {"a": atab, "g": gtab, "f": ftab} + cache = { + "g": InterpolatedUnivariateSpline(atab, gtab), + "f": InterpolatedUnivariateSpline(atab, ftab), + } cosmo._workspace["background.growth_factor"] = cache else: cache = cosmo._workspace["background.growth_factor"] - return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) + return np.clip(cache["g"](a), 0.0, 1.0) def _growth_rate_ODE(cosmo, a): @@ -506,10 +511,10 @@ def _growth_rate_ODE(cosmo, a): if not "background.growth_factor" in cosmo._workspace.keys(): _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) cache = cosmo._workspace["background.growth_factor"] - return interp(a, cache["a"], cache["f"]) + return cache["f"](a) -def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128): +def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=64): r""" Computes growth factor by integrating the growth rate provided by the \gamma parametrization. Normalized such that D( a=1) =1 @@ -538,11 +543,11 @@ def integrand(y, loga): gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab))) gtab = gtab / gtab[-1] # Normalize to a=1. - cache = {"a": atab, "g": gtab} + cache = {"g": InterpolatedUnivariateSpline(atab, gtab)} cosmo._workspace["background.growth_factor"] = cache else: cache = cosmo._workspace["background.growth_factor"] - return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) + return np.clip(cache["g"](a), 0.0, 1.0) def _growth_rate_gamma(cosmo, a): diff --git a/tests/test_angular_cl.py b/tests/test_angular_cl.py index 96a2704..53a1a6d 100644 --- a/tests/test_angular_cl.py +++ b/tests/test_angular_cl.py @@ -22,7 +22,7 @@ def test_lensing_cl(): n_s=0.96, Neff=0, transfer_function="eisenstein_hu", - matter_power_spectrum="halofit", + matter_power_spectrum="linear", ) cosmo_jax = Cosmology( @@ -43,13 +43,13 @@ def test_lensing_cl(): tracer_jax = probes.WeakLensing([nz]) # Get an ell range for the cls - ell = np.logspace(0.1, 4) + ell = np.logspace(1, 4) # Compute the cls cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell) cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax]) - assert_allclose(cl_ccl, cl_jax[0], rtol=1.0e-2) + assert_allclose(cl_ccl, cl_jax[0], rtol=1e-2) def test_lensing_cl_IA(): @@ -62,7 +62,7 @@ def test_lensing_cl_IA(): n_s=0.96, Neff=0, transfer_function="eisenstein_hu", - matter_power_spectrum="halofit", + matter_power_spectrum="linear", ) cosmo_jax = Cosmology( @@ -90,13 +90,13 @@ def test_lensing_cl_IA(): tracer_jax = probes.WeakLensing([nz], bias) # Get an ell range for the cls - ell = np.logspace(0.1, 4) + ell = np.logspace(1, 4) # Compute the cls cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell) cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax]) - assert_allclose(cl_ccl, cl_jax[0], rtol=1e-2) + assert_allclose(cl_ccl, cl_jax[0], rtol=5e-3) def test_clustering_cl(): @@ -109,7 +109,7 @@ def test_clustering_cl(): n_s=0.96, Neff=0, transfer_function="eisenstein_hu", - matter_power_spectrum="halofit", + matter_power_spectrum="linear", ) cosmo_jax = Cosmology( @@ -136,7 +136,7 @@ def test_clustering_cl(): tracer_jax = probes.NumberCounts([nz], bias) # Get an ell range for the cls - ell = np.logspace(0.1, 4) + ell = np.logspace(1, 4) # Compute the cls cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell) diff --git a/tests/test_background.py b/tests/test_background.py index 4e973a2..35e05eb 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -36,15 +36,15 @@ def test_distances_flat(): chi_ccl = ccl.comoving_radial_distance(cosmo_ccl, a) chi_jax = bkgrd.radial_comoving_distance(cosmo_jax, a) / cosmo_jax.h - assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2) + assert_allclose(chi_ccl, chi_jax, rtol=1e-3) chi_ccl = ccl.comoving_angular_distance(cosmo_ccl, a) chi_jax = bkgrd.transverse_comoving_distance(cosmo_jax, a) / cosmo_jax.h - assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2) + assert_allclose(chi_ccl, chi_jax, rtol=1e-3) chi_ccl = ccl.angular_diameter_distance(cosmo_ccl, a) chi_jax = bkgrd.angular_diameter_distance(cosmo_jax, a) / cosmo_jax.h - assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2) + assert_allclose(chi_ccl, chi_jax, rtol=1e-3) def test_growth(): @@ -72,12 +72,12 @@ def test_growth(): ) # Test array of scale factors - a = np.linspace(0.01, 1.0) + a = np.linspace(0.1, 1.0) gccl = ccl.growth_factor(cosmo_ccl, a) gjax = bkgrd.growth_factor(cosmo_jax, a) - assert_allclose(gccl, gjax, rtol=1e-2) + assert_allclose(gccl, gjax, rtol=1e-3) def test_growth_rate(): @@ -105,12 +105,12 @@ def test_growth_rate(): ) # Test array of scale factors - a = np.linspace(0.01, 1.0) + a = np.linspace(0.1, 1.0) fccl = ccl.growth_rate(cosmo_ccl, a) fjax = bkgrd.growth_rate(cosmo_jax, a) - assert_allclose(fccl, fjax, rtol=1e-2) + assert_allclose(fccl, fjax, rtol=1e-3) def test_growth_rate_gamma(): @@ -140,12 +140,12 @@ def test_growth_rate_gamma(): ) # Test array of scale factors - a = np.linspace(0.01, 1.0) + a = np.linspace(0.1, 1.0) fccl = ccl.growth_rate(cosmo_ccl, a) fjax = bkgrd.growth_rate(cosmo_jax, a) - assert_allclose(fccl, fjax, rtol=1e-2) + assert_allclose(fccl, fjax, rtol=5e-3) def test_growth_gamma(): @@ -174,9 +174,9 @@ def test_growth_gamma(): ) # Test array of scale factors - a = np.linspace(0.01, 1.0) + a = np.linspace(0.1, 1.0) gccl = ccl.growth_factor(cosmo_ccl, a) gjax = bkgrd.growth_factor(cosmo_jax, a) - assert_allclose(gccl, gjax, rtol=1e-2) + assert_allclose(gccl, gjax, rtol=1e-3)