diff --git a/jax_cosmo/angular_cl.py b/jax_cosmo/angular_cl.py index 63db4dd..62a5721 100644 --- a/jax_cosmo/angular_cl.py +++ b/jax_cosmo/angular_cl.py @@ -15,6 +15,7 @@ import jax_cosmo.power as power import jax_cosmo.transfer as tklib from jax_cosmo.scipy.integrate import simps +from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline from jax_cosmo.utils import a2z from jax_cosmo.utils import z2a @@ -54,7 +55,7 @@ def find_index(a, b): def angular_cl( - cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit + cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.linear ): """ Computes angular Cls for the provided probes @@ -93,13 +94,16 @@ def combine_kernels(inds): # Now kernels has shape [ncls, na] kernels = lax.map(combine_kernels, cl_index) - result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi ** 2, 1.0) + return result - # We transpose the result just to make sure that na is first - return result.T - - return simps(integrand, z2a(zmax), 1.0, 512) / const.c ** 2 + atab = np.linspace(z2a(zmax), 1.0, 64) + eval_integral = vmap( + lambda x: np.squeeze( + InterpolatedUnivariateSpline(atab, x).integral(z2a(zmax), 1.0) + ) + ) + return eval_integral(integrand(atab)) / const.c ** 2 return cl(ell) @@ -161,7 +165,7 @@ def gaussian_cl_covariance_and_mean( ell, probes, transfer_fn=tklib.Eisenstein_Hu, - nonlinear_fn=power.halofit, + nonlinear_fn=power.linear, f_sky=0.25, ): """ diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index 66a2d94..7c15f7b 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -439,7 +439,7 @@ def growth_rate(cosmo, a): return _growth_rate_ODE(cosmo, a) -def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=32): +def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=64): """ Compute linear growth factor D(a) at a given scale factor, normalised such that D(a=1) = 1. @@ -514,7 +514,7 @@ def _growth_rate_ODE(cosmo, a): return cache["f"](a) -def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=32): +def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=64): r""" Computes growth factor by integrating the growth rate provided by the \gamma parametrization. Normalized such that D( a=1) =1 diff --git a/tests/test_angular_cl.py b/tests/test_angular_cl.py index 44af0b1..1b96d75 100644 --- a/tests/test_angular_cl.py +++ b/tests/test_angular_cl.py @@ -22,7 +22,7 @@ def test_lensing_cl(): n_s=0.96, Neff=0, transfer_function="eisenstein_hu", - matter_power_spectrum="halofit", + matter_power_spectrum="linear", ) cosmo_jax = Cosmology( @@ -43,13 +43,13 @@ def test_lensing_cl(): tracer_jax = probes.WeakLensing([nz]) # Get an ell range for the cls - ell = np.logspace(0.1, 4) + ell = np.logspace(1, 4) # Compute the cls cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell) cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax]) - assert_allclose(cl_ccl, cl_jax[0], rtol=0.5e-2) + assert_allclose(cl_ccl, cl_jax[0], rtol=5e-3) def test_lensing_cl_IA(): @@ -62,7 +62,7 @@ def test_lensing_cl_IA(): n_s=0.96, Neff=0, transfer_function="eisenstein_hu", - matter_power_spectrum="halofit", + matter_power_spectrum="linear", ) cosmo_jax = Cosmology( @@ -90,13 +90,13 @@ def test_lensing_cl_IA(): tracer_jax = probes.WeakLensing([nz], bias) # Get an ell range for the cls - ell = np.logspace(0.1, 4) + ell = np.logspace(1, 4) # Compute the cls cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell) cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax]) - assert_allclose(cl_ccl, cl_jax[0], rtol=1e-2) + assert_allclose(cl_ccl, cl_jax[0], rtol=5e-3) def test_clustering_cl(): @@ -109,7 +109,7 @@ def test_clustering_cl(): n_s=0.96, Neff=0, transfer_function="eisenstein_hu", - matter_power_spectrum="halofit", + matter_power_spectrum="linear", ) cosmo_jax = Cosmology( @@ -136,10 +136,10 @@ def test_clustering_cl(): tracer_jax = probes.NumberCounts([nz], bias) # Get an ell range for the cls - ell = np.logspace(0.1, 4) + ell = np.logspace(1, 4) # Compute the cls cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell) cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax]) - assert_allclose(cl_ccl, cl_jax[0], rtol=0.5e-2) + assert_allclose(cl_ccl, cl_jax[0], rtol=5e-3)