Skip to content

Commit

Permalink
Merge pull request #17 from DifferentiableUniverseInitiative/number_c…
Browse files Browse the repository at this point in the history
…ounts

Adds clustering probe
  • Loading branch information
EiffL authored May 10, 2020
2 parents a879a8e + 2780496 commit ea3ccf3
Show file tree
Hide file tree
Showing 9 changed files with 1,150 additions and 50 deletions.
80 changes: 77 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,81 @@
# jax_cosmo
# jax-cosmo

[![Join the chat at https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo](https://badges.gitter.im/DifferentiableUniverseInitiative/jax_cosmo.svg)](https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ![Python package](https://github.com/DifferentiableUniverseInitiative/jax_cosmo/workflows/Python%20package/badge.svg)

A differentiable cosmology library in JAX
A differentiable cosmology library in JAX.

This reposotory is a prototype for a differentiable cosmology library written in Jax. For this prototype we are going to keep things simple, and port as much exciting code as possible from a pure Python package: https://github.com/cosmicpy/cosmicpy
**Note**: This package is still in the development phase, expect changes to the API. We hope to make this project a community effort, contributions of all kind are most welcome!
Have a look at the [GitHub issues](https://github.com/DifferentiableUniverseInitiative/jax_cosmo/issues) to see what is needed or if you have any thoughts on the design, and don't hesitate to join the [Gitter room](https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo) for discussions.

## TL;DR

This is what `jax-cosmo` aims to do:

```python
data = #... some measured Cl data vector
nz1,nz2,nz3,nz4 = #.... redshift distributions of bins
def likelihood(cosmo):
# Define a list of probes
probes = [jax_cosmo.probes.WeakLensing([nz1, nz2, nz3, nz4]),
jax_cosmo.probes.NumberCounts([nz1, nz2, nz3, nz4])]

# Compute mean and covariance of angular Cls
mu, cov = jax_cosmo.angular_cl.gaussian_cl_covariance(cosmo, ell, probes)

# Return likelihood value
return jax_cosmo.likelihood.gaussian_log_likelihood(data, mu, cov)

# Compute derivatives of the likelihood with respect to cosmological parameters
g = jax.grad(likelihood)(cosmo)

# Compute Fisher matrix of cosmological parameters
F = - jax.hessian(likelihood)(cosmo)
```
This is how you can compute gradients and hessians of any functions in `jax-cosmo`,
all of this without any finite differences.

Check out a full example here:

## What is JAX?

[JAX](https://github.com/google/jax) = NumPy + autodiff + GPU

JAX is a framework for automatic differentiation (like TensorFlow or PyTorch) but following the NumPy API, and using the GPU/TPU enable XLA backend.

What does that mean?
- You write plain Python/NumPy code, no need to learn a different language
- It runs on GPU, you don't need to do anything particular
- You can take derivatives of any quantity with respect to any parameters by
automatic differentiation.

Checkout the [JAX](https://github.com/google/jax) project page to learn more!

## Install

`jax-cosmo` is pure Python, so installing is a breeze:
```bash
$ pip install jax-cosmo
```

## Philosophy

Here are some of the design guidelines:
- Implementation of equations should be human readable, and documentation should always live next to the implementation.
- Should always be trivially installable: external dependencies should be kept
to a minimum, especially the ones that require compilation or with restrictive licenses.
- Keep API and implementation simple and intuitive, minimize user and developer
surprise.
- “Debugging is twice as hard as writing the code in the first place. Therefore, if you write the code as cleverly as possible, you are, by definition, not smart enough to debug it.” -Brian Kernighan, quote stolen from
[here](https://flax.readthedocs.io/en/latest/philosophy.html).

## Contributing

`jax-cosmo` aims to be a community effort, contributions are most welcome and
can come in several forms
- Bug reports
- API design suggestions
- (Pull) requests for more features
- Examples and notebooks of cool things that can be done with the code

The issue page is a good place to start, but don't hesitate to come chat in the
Gitter room.
5 changes: 5 additions & 0 deletions jax_cosmo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@
from jax_cosmo.parameters import *
import jax_cosmo.background as background
import jax_cosmo.power as power
import jax_cosmo.redshift as redshift
import jax_cosmo.angular_cl as cl
import jax_cosmo.bias as bias
import jax_cosmo.probes as probes
import jax_cosmo.likelihood as likelihood
10 changes: 6 additions & 4 deletions jax_cosmo/angular_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax.numpy as np
from jax import vmap, lax, jit

import jax_cosmo.constants as const
from jax_cosmo.utils import z2a, a2z
from jax_cosmo.scipy.integrate import simps
import jax_cosmo.background as bkgrd
Expand Down Expand Up @@ -73,18 +74,19 @@ def integrand(a):

# Define an ordering for the blocks of the signal vector
cl_index = np.array(_get_cl_ordering(probes))

# Compute all combinations of tracers
@jit
def combine_kernels(inds):
return kernels[inds[0]] * kernels[inds[1]]

# Now kernels has shape [ncls, na]
kernels = lax.map(combine_kernels, cl_index)

result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi**2, 1.)

# We transpose the result just to make sure that na is first
return (pk * kernels * bkgrd.dchioverda(cosmo, a)/a**2).T
return result.T

return simps(integrand, amin, 1., 512)
return simps(integrand, amin, 1., 512) / const.c**2

def noise_cl(ell, probes):
"""
Expand Down
20 changes: 20 additions & 0 deletions jax_cosmo/bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# This module contains implementations of galaxy bias
import jax.numpy as np
from jax.tree_util import register_pytree_node_class

from jax_cosmo.jax_utils import container

@register_pytree_node_class
class constant_linear_bias(container):
"""
Class representing a linear bias
Parameters:
-----------
b: redshift independent bias value
"""
def __call__(self, z):
"""
"""
b = self.params[0]
return b * np.ones_like(z)
98 changes: 80 additions & 18 deletions jax_cosmo/probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jax_cosmo.jax_utils import container
from jax.tree_util import register_pytree_node_class

__all__ = ["WeakLensing"]
__all__ = ["WeakLensing", "NumberCounts"]

@register_pytree_node_class
class WeakLensing(container):
Expand All @@ -21,16 +21,18 @@ class WeakLensing(container):
-----------
redshift_bins: nzredshift distributions
sigma_e: intrinsic galaxy ellipticity
n_eff: effective number density per bins [1./arcmin^2]
Configuration:
--------------
has_shear:, ia_bias, use_bias.... use_shear is not functional
sigma_e: intrinsic galaxy ellipticity
has_shear:, ia_bias, use_bias.... these are not functional
"""
def __init__(self, redshift_bins,
sigma_e=0.26,
use_shear=True,
**kwargs):
super(WeakLensing, self).__init__(redshift_bins,
sigma_e=sigma_e,
use_shear=use_shear,
**kwargs)
@property
Expand All @@ -43,7 +45,7 @@ def n_tracers(self):
return len(pzs)

def constant_factor(self, cosmo):
return 3.0 * const.H0**2 * cosmo.Omega_m / (2.0 * const.c**2)
return 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c

@jit
def radial_kernel(self, cosmo, z):
Expand All @@ -68,8 +70,11 @@ def integrand(z_prime):
chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime))
# Stack the dndz of all redshift bins
dndz = np.stack([pz(z_prime) for pz in pzs], axis=0)
return dndz * np.clip(chi_prime - chi, 0) / (chi_prime + 1e-5)
return np.squeeze(simps(integrand, z, zmax, 256))
return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.)

result = simps(integrand, z, zmax, 256) * (1. + z ) * chi

return np.squeeze(result)

def ell_factor(self, ell):
"""
Expand All @@ -84,20 +89,77 @@ def noise(self):
"""
# Extract parameters
pzs = self.params[0]
print("Careful! lensing noise is still a placeholder")
# TODO: find where to store this
sigma_e = 0.26
n_eff = 20.

# retrieve number of galaxies in each bins
# ngals = np.array([pz.ngals for pz in pzs])
# TODO: find how to properly compute the noise contributions
ngals = np.array([1. for pz in pzs])
ngals = np.array([pz.gals_per_steradian for pz in pzs])

# TODO: add mechanism for effective number density, maybe a bin dependent
# efficiency
return self.config['sigma_e']**2 / ngals

# compute n_eff per bin
n_eff_per_bin = n_eff * ngals/np.sum(ngals)

# work out the number density per steradian
steradian_to_arcmin2 = 11818102.86004228
@register_pytree_node_class
class NumberCounts(container):
"""
Class representing a galaxy clustering probe, with a bunch of bins
Parameters:
-----------
redshift_bins: nzredshift distributions
Configuration:
--------------
has_rsd....
"""
def __init__(self, redshift_bins, bias,
has_rsd=False,
**kwargs):
super(NumberCounts, self).__init__(redshift_bins,
bias,
has_rsd=has_rsd,
**kwargs)
@property
def n_tracers(self):
"""
Returns the number of tracers for this probe, i.e. redshift bins
"""
# Extract parameters
pzs = self.params[0]
return len(pzs)

def constant_factor(self, cosmo):
return 1.0

return sigma_e**2/(n_eff_per_bin * steradian_to_arcmin2)
@jit
def radial_kernel(self, cosmo, z):
"""
Compute the radial kernel for all nz bins in this probe.
Returns:
--------
radial_kernel: shape (nbins, nz)
"""
z = np.atleast_1d(z)
# Extract parameters
pzs, bias = self.params

# stack the dndz of all redshift bins
dndz = np.stack([pz(z) for pz in pzs], axis=0)

return dndz * bias(z) * bkgrd.H(cosmo, z2a(z))

def ell_factor(self, ell):
"""
Computes the ell dependent factor for this probe.
"""
return 1.

def noise(self):
"""
Returns the noise power for all redshifts
return: shape [nbins]
"""
# Extract parameters
pzs = self.params[0]
ngals = np.array([pz.gals_per_steradian for pz in pzs])
return 1./ngals
40 changes: 37 additions & 3 deletions jax_cosmo/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.jax_utils import container

steradian_to_arcmin2 = 11818102.86004228

__all__ = ["smail_nz"]

class redshift_distribution(container):

def __init__(self, *args, zmax=10., **kwargs):
def __init__(self, *args, gals_per_arcmin2=1., zmax=10., **kwargs):
"""
Initialize the parameters of the redshift distribution
"""
self._norm = None
self._gals_per_arcmin2 = gals_per_arcmin2
super(redshift_distribution, self).__init__(*args,
zmax=zmax,
**kwargs)
Expand All @@ -35,7 +40,34 @@ def __call__(self, z):
@property
def zmax(self):
return self.config['zmax']


@property
def gals_per_arcmin2(self):
"""
Returns the number density of galaxies in gals/sq arcmin
TODO: find a better name
"""
return self._gals_per_arcmin2

@property
def gals_per_steradian(self):
"""
Returns the number density of galaxies in steradian
"""
return self._gals_per_arcmin2 * steradian_to_arcmin2

# Operations for flattening/unflattening representation
def tree_flatten(self):
children = (self.params, self._gals_per_arcmin2)
aux_data = self.config
return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
args, gals_per_arcmin2 = children
return cls(*args, gals_per_arcmin2=gals_per_arcmin2,
**aux_data)

@register_pytree_node_class
class smail_nz(redshift_distribution):
"""
Expand All @@ -47,7 +79,9 @@ class smail_nz(redshift_distribution):
b:
z0
z0:
gals_per_arcmin2: number of galaxies per sq arcmin
"""
def pz_fn(self, z):
a, b, z0 = self.params
Expand Down
Loading

0 comments on commit ea3ccf3

Please sign in to comment.