diff --git a/.all-contributorsrc b/.all-contributorsrc
index 35446a7..8ce102b 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -23,6 +23,24 @@
"bug",
"code"
]
+ },
+ {
+ "login": "austinpeel",
+ "name": "Austin Peel",
+ "avatar_url": "https://avatars0.githubusercontent.com/u/17024310?v=4",
+ "profile": "https://github.com/austinpeel",
+ "contributions": [
+ "code"
+ ]
+ },
+ {
+ "login": "minaskar",
+ "name": "Minas Karamanis",
+ "avatar_url": "https://avatars2.githubusercontent.com/u/23280751?v=4",
+ "profile": "https://minaskaramanis.com",
+ "contributions": [
+ "code"
+ ]
}
],
"contributorsPerLine": 7,
diff --git a/README.md b/README.md
index f21702b..c0c8c1e 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# jax-cosmo
[](https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [](https://jax-cosmo.readthedocs.io/en/latest/?badge=latest) []() [](https://github.com/psf/black) [](https://pypi.org/project/jax-cosmo/) [](https://github.com/google/jax-cosmo/blob/master/LICENSE)
-[](#contributors-)
+[](#contributors-)
Finally a differentiable cosmology library, and it's in JAX!
@@ -92,6 +92,8 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
 Francois Lanusse 💻 |
 Santiago Casas 🐛 💻 |
+  Austin Peel 💻 |
+  Minas Karamanis 💻 |
diff --git a/jax_cosmo/likelihood.py b/jax_cosmo/likelihood.py
index cac1c08..5ad295a 100644
--- a/jax_cosmo/likelihood.py
+++ b/jax_cosmo/likelihood.py
@@ -4,6 +4,7 @@
from __future__ import print_function
import jax.numpy as np
+import jax.scipy as sp
from jax_cosmo.angular_cl import gaussian_cl_covariance
@@ -18,14 +19,14 @@ def gaussian_log_likelihood(data, mu, C, constant_cov=True, inverse_method="inve
# TODO: check what is the fastest and works the best between cholesky+solve
# and just inversion
if inverse_method == "inverse":
- y = np.linalg.inv(C) @ r
+ y = np.dot(np.linalg.inv(C), r)
elif inverse_method == "cholesky":
- raise NotImplementedError
+ y = sp.linalg.cho_solve(sp.linalg.cho_factor(C, lower=True), r)
else:
raise NotImplementedError
if constant_cov:
- return -0.5 * (r.T @ y)
+ return -0.5 * r.dot(y)
else:
_, logdet = np.linalg.slogdet(C)
- return -0.5 * (r.T @ y) - 0.5 * logdet
+ return -0.5 * r.dot(y) - 0.5 * logdet
diff --git a/tests/test_angular_cl.py b/tests/test_angular_cl.py
index 1b96d75..53a1a6d 100644
--- a/tests/test_angular_cl.py
+++ b/tests/test_angular_cl.py
@@ -49,7 +49,7 @@ def test_lensing_cl():
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell)
cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax])
- assert_allclose(cl_ccl, cl_jax[0], rtol=5e-3)
+ assert_allclose(cl_ccl, cl_jax[0], rtol=1e-2)
def test_lensing_cl_IA():
@@ -142,4 +142,4 @@ def test_clustering_cl():
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell)
cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax])
- assert_allclose(cl_ccl, cl_jax[0], rtol=5e-3)
+ assert_allclose(cl_ccl, cl_jax[0], rtol=1e-2)