Skip to content

Commit

Permalink
new toy example, review fixes, sharding
Browse files Browse the repository at this point in the history
  • Loading branch information
homerjed committed Jan 4, 2025
1 parent f5eec6a commit 2bdc33b
Show file tree
Hide file tree
Showing 12 changed files with 821 additions and 130 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ __pycache__/
.DS_Store

affine.py
_experiment.py
_experiment.py
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ and have a look at [examples](https://github.com/homerjed/sbiax/tree/main/exampl

### Citation

If you found this library to be useful in academic work, then please cite: <!--([arXiv link](https://arxiv.org/abs/2111.00254)) -->
If you found this library to be useful in academic work, please cite: <!--([arXiv link](https://arxiv.org/abs/2111.00254)) -->

```bibtex
@misc{homer2024simulationbasedinferencedodelsonschneidereffect,
Expand Down
11 changes: 4 additions & 7 deletions data/shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
import tensorflow_probability.substrates.jax.distributions as tfd


def get_shear_experiment():
data_dir = "/Users/Jed.Homer/phd/sbiax/data/shear/"

covariance = np.loadtxt(data_dir + "covariance_cosmic_shear_PMEpaper.dat")
def get_shear_experiment(data_dir):
covariance = np.loadtxt(data_dir + "covariance_cosmic_shear_PMEpaper.dat") # / (18. / 5.)
precision = np.linalg.inv(np.matrix(covariance))
mu = np.loadtxt(data_dir + "DES_shear-shear_a1.0_b0.5_data_vector")[:, 1]
derivatives = np.loadtxt(data_dir + "derivatives.dat").T
Expand Down Expand Up @@ -39,7 +37,6 @@ def get_shear_experiment():


def linearized_model(_alpha, mu, alpha, derivatives):
""" Linearised model always uses true mu_0, C """
return mu + jnp.dot(_alpha - alpha, derivatives)


Expand Down Expand Up @@ -100,7 +97,7 @@ def _mle(d, pi, Finv, mu, dmu, precision):
return pi + jnp.linalg.multi_dot([Finv, dmu, precision, d - mu])


def get_experiment_data(key, true_covariance, n_sims, *, results_dir):
def get_experiment_data(key, true_covariance, n_sims, *, results_dir, data_dir):

key, key_prior, key_simulate = jr.split(key, 3)

Expand All @@ -116,7 +113,7 @@ def get_experiment_data(key, true_covariance, n_sims, *, results_dir):
Finv,
lower,
upper
) = get_shear_experiment()
) = get_shear_experiment(data_dir=data_dir)

# Estimate covariance, Fisher information and precision given n_sims
if true_covariance:
Expand Down
45 changes: 26 additions & 19 deletions examples/shear.ipynb

Large diffs are not rendered by default.

141 changes: 98 additions & 43 deletions examples/shear_nn.ipynb

Large diffs are not rendered by default.

79 changes: 41 additions & 38 deletions examples/shear_optuna.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 2bdc33b

Please sign in to comment.