Skip to content

Commit

Permalink
Merge pull request #70 from DifferentiableUniverseInitiative/single_s…
Browse files Browse the repository at this point in the history
…ource

Single source
  • Loading branch information
EiffL authored May 18, 2021
2 parents 63ccb61 + 5857ca7 commit 9b7e3c4
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 9 deletions.
60 changes: 52 additions & 8 deletions jax_cosmo/probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import jax_cosmo.background as bkgrd
import jax_cosmo.constants as const
import jax_cosmo.redshift as rds
from jax_cosmo.jax_utils import container
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.utils import a2z
Expand All @@ -18,21 +19,56 @@
def weak_lensing_kernel(cosmo, pzs, z, ell):
"""
Returns a weak lensing kernel
Note: this function handles differently nzs that correspond to extended redshift
distribution, and delta functions.
"""
z = np.atleast_1d(z)
zmax = max([pz.zmax for pz in pzs])
# Retrieve comoving distance corresponding to z
chi = bkgrd.radial_comoving_distance(cosmo, z2a(z))

@vmap
def integrand(z_prime):
chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime))
# Stack the dndz of all redshift bins
dndz = np.stack([pz(z_prime) for pz in pzs], axis=0)
return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0)
# Extract the indices of pzs that can be treated as extended distributions,
# and the ones that need to be treated as delta functions.
pzs_extended_idx = [
i for i, pz in enumerate(pzs) if not isinstance(pz, rds.delta_nz)
]
pzs_delta_idx = [i for i, pz in enumerate(pzs) if isinstance(pz, rds.delta_nz)]
# Here we define a permutation that would put all extended pzs at the begining of the list
perm = pzs_extended_idx + pzs_delta_idx
# Compute inverse permutation
inv = np.argsort(np.array(perm, dtype=np.int32))

# Process extended distributions, if any
radial_kernels = []
if len(pzs_extended_idx) > 0:

@vmap
def integrand(z_prime):
chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime))
# Stack the dndz of all redshift bins
dndz = np.stack([pzs[i](z_prime) for i in pzs_extended_idx], axis=0)
return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0)

radial_kernels.append(simps(integrand, z, zmax, 256) * (1.0 + z) * chi)
# Process single plane redshifts if any
if len(pzs_delta_idx) > 0:

@vmap
def integrand_single(z_prime):
chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime))
return np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0)

radial_kernels.append(
integrand_single(np.array([pzs[i].params[0] for i in pzs_delta_idx]))
* (1.0 + z)
* chi
)
# Fusing the results together
radial_kernel = np.concatenate(radial_kernels, axis=0)
# And perfoming inverse permutation to put all the indices where they should be
radial_kernel = radial_kernel[inv]

# Computes the radial weak lensing kernel
radial_kernel = np.squeeze(simps(integrand, z, zmax, 256) * (1.0 + z) * chi)
# Constant term
constant_factor = 3.0 * const.H0 ** 2 * cosmo.Omega_m / 2.0 / const.c
# Ell dependent factor
Expand All @@ -45,6 +81,10 @@ def density_kernel(cosmo, pzs, bias, z, ell):
"""
Computes the number counts density kernel
"""
if any(isinstance(pz, rds.delta_nz) for pz in pzs):
raise NotImplementedError(
"Density kernel not properly implemented for delta redshift distributions"
)
# stack the dndz of all redshift bins
dndz = np.stack([pz(z) for pz in pzs], axis=0)
# Compute radial NLA kernel: same as clustering
Expand All @@ -66,6 +106,10 @@ def nla_kernel(cosmo, pzs, bias, z, ell):
"""
Computes the NLA IA kernel
"""
if any(isinstance(pz, rds.delta_nz) for pz in pzs):
raise NotImplementedError(
"NLA kernel not properly implemented for delta redshift distributions"
)
# stack the dndz of all redshift bins
dndz = np.stack([pz(z) for pz in pzs], axis=0)
# Compute radial NLA kernel: same as clustering
Expand Down
20 changes: 19 additions & 1 deletion jax_cosmo/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

steradian_to_arcmin2 = 11818102.86004228

__all__ = ["smail_nz", "kde_nz"]
__all__ = ["smail_nz", "kde_nz", "delta_nz"]


class redshift_distribution(container):
Expand Down Expand Up @@ -78,6 +78,24 @@ def pz_fn(self, z):
return z ** a * np.exp(-((z / z0) ** b))


@register_pytree_node_class
class delta_nz(redshift_distribution):
"""Defines a single plane redshift distribution with these arguments
Parameters:
-----------
z0:
"""

def __init__(self, *args, **kwargs):
"""Initialize the parameters of the redshift distribution"""
super(delta_nz, self).__init__(*args, **kwargs)
self._norm = 1.0

def pz_fn(self, z):
z0 = self.params
return np.where(z == z0, 1.0, 0)


@register_pytree_node_class
class kde_nz(redshift_distribution):
"""A redshift distribution based on a KDE estimate of the nz of a
Expand Down
54 changes: 54 additions & 0 deletions tests/test_angular_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from jax_cosmo.angular_cl import gaussian_cl_covariance
from jax_cosmo.bias import constant_linear_bias
from jax_cosmo.bias import inverse_growth_linear_bias
from jax_cosmo.redshift import delta_nz
from jax_cosmo.redshift import kde_nz
from jax_cosmo.redshift import smail_nz
from jax_cosmo.sparse import to_dense

Expand Down Expand Up @@ -55,6 +57,58 @@ def test_lensing_cl():
assert_allclose(cl_ccl, cl_jax[0], rtol=1.0e-2)


def test_lensing_cl_delta():
# We first define equivalent CCL and jax_cosmo cosmologies
cosmo_ccl = ccl.Cosmology(
Omega_c=0.3,
Omega_b=0.05,
h=0.7,
sigma8=0.8,
n_s=0.96,
Neff=0,
transfer_function="eisenstein_hu",
matter_power_spectrum="halofit",
)

cosmo_jax = Cosmology(
Omega_c=0.3,
Omega_b=0.05,
h=0.7,
sigma8=0.8,
n_s=0.96,
Omega_k=0.0,
w0=-1.0,
wa=0.0,
)

# Define a redshift distribution
z0 = 1.0
z = np.linspace(0, 5.0, 1024)
pz = np.zeros_like(z)
pz[np.argmin(abs(z0 - z))] = 1.0
nzs_s = kde_nz(z, pz, bw=0.01)
nz = delta_nz(z0)
nz_smail1 = smail_nz(1.0, 2.0, 1.0)
nz_smail2 = smail_nz(1.4, 2.0, 1.0)
tracer_ccl = ccl.WeakLensingTracer(cosmo_ccl, (z, nzs_s(z)), use_A_ia=False)
tracer_cclb = ccl.WeakLensingTracer(cosmo_ccl, (z, nz_smail2(z)), use_A_ia=False)
tracer_jax = probes.WeakLensing([nz])
tracer_jaxb = probes.WeakLensing([nz, nz_smail1, nz_smail2])

# Get an ell range for the cls
ell = np.logspace(0.1, 4)

# Compute the cls
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell)
cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax])
assert_allclose(cl_ccl, cl_jax[0], rtol=1.0e-2)

# Also test if several nzs are provided
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_cclb, tracer_cclb, ell)
cl_jax = angular_cl(cosmo_jax, ell, [tracer_jaxb])
assert_allclose(cl_ccl, cl_jax[-1], rtol=1.0e-2)


def test_lensing_cl_IA():
# We first define equivalent CCL and jax_cosmo cosmologies
cosmo_ccl = ccl.Cosmology(
Expand Down

0 comments on commit 9b7e3c4

Please sign in to comment.