-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from DifferentiableUniverseInitiative/number_c…
…ounts Adds clustering probe
- Loading branch information
Showing
9 changed files
with
1,150 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,81 @@ | ||
# jax_cosmo | ||
# jax-cosmo | ||
|
||
[![Join the chat at https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo](https://badges.gitter.im/DifferentiableUniverseInitiative/jax_cosmo.svg)](https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ![Python package](https://github.com/DifferentiableUniverseInitiative/jax_cosmo/workflows/Python%20package/badge.svg) | ||
|
||
A differentiable cosmology library in JAX | ||
A differentiable cosmology library in JAX. | ||
|
||
This reposotory is a prototype for a differentiable cosmology library written in Jax. For this prototype we are going to keep things simple, and port as much exciting code as possible from a pure Python package: https://github.com/cosmicpy/cosmicpy | ||
**Note**: This package is still in the development phase, expect changes to the API. We hope to make this project a community effort, contributions of all kind are most welcome! | ||
Have a look at the [GitHub issues](https://github.com/DifferentiableUniverseInitiative/jax_cosmo/issues) to see what is needed or if you have any thoughts on the design, and don't hesitate to join the [Gitter room](https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo) for discussions. | ||
|
||
## TL;DR | ||
|
||
This is what `jax-cosmo` aims to do: | ||
|
||
```python | ||
data = #... some measured Cl data vector | ||
nz1,nz2,nz3,nz4 = #.... redshift distributions of bins | ||
def likelihood(cosmo): | ||
# Define a list of probes | ||
probes = [jax_cosmo.probes.WeakLensing([nz1, nz2, nz3, nz4]), | ||
jax_cosmo.probes.NumberCounts([nz1, nz2, nz3, nz4])] | ||
|
||
# Compute mean and covariance of angular Cls | ||
mu, cov = jax_cosmo.angular_cl.gaussian_cl_covariance(cosmo, ell, probes) | ||
|
||
# Return likelihood value | ||
return jax_cosmo.likelihood.gaussian_log_likelihood(data, mu, cov) | ||
|
||
# Compute derivatives of the likelihood with respect to cosmological parameters | ||
g = jax.grad(likelihood)(cosmo) | ||
|
||
# Compute Fisher matrix of cosmological parameters | ||
F = - jax.hessian(likelihood)(cosmo) | ||
``` | ||
This is how you can compute gradients and hessians of any functions in `jax-cosmo`, | ||
all of this without any finite differences. | ||
|
||
Check out a full example here: | ||
|
||
## What is JAX? | ||
|
||
[JAX](https://github.com/google/jax) = NumPy + autodiff + GPU | ||
|
||
JAX is a framework for automatic differentiation (like TensorFlow or PyTorch) but following the NumPy API, and using the GPU/TPU enable XLA backend. | ||
|
||
What does that mean? | ||
- You write plain Python/NumPy code, no need to learn a different language | ||
- It runs on GPU, you don't need to do anything particular | ||
- You can take derivatives of any quantity with respect to any parameters by | ||
automatic differentiation. | ||
|
||
Checkout the [JAX](https://github.com/google/jax) project page to learn more! | ||
|
||
## Install | ||
|
||
`jax-cosmo` is pure Python, so installing is a breeze: | ||
```bash | ||
$ pip install jax-cosmo | ||
``` | ||
|
||
## Philosophy | ||
|
||
Here are some of the design guidelines: | ||
- Implementation of equations should be human readable, and documentation should always live next to the implementation. | ||
- Should always be trivially installable: external dependencies should be kept | ||
to a minimum, especially the ones that require compilation or with restrictive licenses. | ||
- Keep API and implementation simple and intuitive, minimize user and developer | ||
surprise. | ||
- “Debugging is twice as hard as writing the code in the first place. Therefore, if you write the code as cleverly as possible, you are, by definition, not smart enough to debug it.” -Brian Kernighan, quote stolen from | ||
[here](https://flax.readthedocs.io/en/latest/philosophy.html). | ||
|
||
## Contributing | ||
|
||
`jax-cosmo` aims to be a community effort, contributions are most welcome and | ||
can come in several forms | ||
- Bug reports | ||
- API design suggestions | ||
- (Pull) requests for more features | ||
- Examples and notebooks of cool things that can be done with the code | ||
|
||
The issue page is a good place to start, but don't hesitate to come chat in the | ||
Gitter room. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# This module contains implementations of galaxy bias | ||
import jax.numpy as np | ||
from jax.tree_util import register_pytree_node_class | ||
|
||
from jax_cosmo.jax_utils import container | ||
|
||
@register_pytree_node_class | ||
class constant_linear_bias(container): | ||
""" | ||
Class representing a linear bias | ||
Parameters: | ||
----------- | ||
b: redshift independent bias value | ||
""" | ||
def __call__(self, z): | ||
""" | ||
""" | ||
b = self.params[0] | ||
return b * np.ones_like(z) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.