-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from mj-will/add-initial-template
Add initial template
- Loading branch information
Showing
6 changed files
with
277 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
name: Unit tests | ||
|
||
on: | ||
push: | ||
branches: [ main ] | ||
pull_request: | ||
branches: [ main ] | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
unittests: | ||
name: Unit tests - Python ${{ matrix.python-version }} | ||
|
||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.9", "3.10", "3.11"] | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
cache: 'pip' | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install .[test] | ||
- name: Test with pytest | ||
run: | | ||
python -m pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,74 @@ | ||
# sampler-template | ||
|
||
Template for sampler plugins in bilby | ||
|
||
There are three main components to a sampler plugin for bilby: | ||
|
||
- The sampler class | ||
- The `pyproject.toml` | ||
- The test suite | ||
|
||
|
||
## The sampler class | ||
|
||
This is the interface between the external sampling code and bilby, in this | ||
template repo is it located in `src/demo_sampler_bilby/plugin.py`. | ||
|
||
The class should inherit from one of: | ||
|
||
- `bilby.core.sampler.Sampler` | ||
- `bilby.core.sampler.MCMCSampler` | ||
- `bilby.core.sampler.NestedSampler` | ||
|
||
The sampler is run in the `run_sampler` method | ||
|
||
|
||
## pyproject.toml | ||
|
||
Various fields in the `pyproject.toml` need to be set, these are: | ||
|
||
* `name`: this is the name Python package, we recommend using `<name of your sampler>_bilby` | ||
* `author`: this should include any authors | ||
* `dependencies`: any dependencies should be included here, this must include `bilby` | ||
* `[project.entry-points."bilby.samplers"]` see below | ||
|
||
|
||
### Adding the entry points | ||
|
||
This section of the `pyproject.toml` makes the sampler 'visible' within bilby. | ||
|
||
``` | ||
[project.entry-points."bilby.samplers"] | ||
demo_sampler = "demo_sampler_bilby.plugin:DemoSampler" | ||
``` | ||
|
||
The name of the sampler within bilby is determined based on the name used here | ||
(e.g. `demo_sampler` in this case). | ||
|
||
The string should points to the file and, after the colon, the sampler class. | ||
|
||
|
||
## Tests | ||
|
||
The plugin should include a test suite. | ||
|
||
`tests/test_bilby_integration.py` includes a standard test that | ||
all samplers should pass but other tests can be included in here. The file should | ||
be updated name matches the name of the sampler provided by the plugin and any keyword | ||
arguments should be set. | ||
|
||
|
||
### Continuous Integration | ||
|
||
The tests are run automatically via GitHub Actions. These are configured in | ||
`.github/workflow/test.yaml`. You may wish to change the version of Python | ||
being tested or included additional configuration, e.g. for more complex installation | ||
processes. | ||
|
||
|
||
## Plugin in an existing package | ||
|
||
It is also possible to include the plugin in an existing package rather | ||
than creating a separate plugin package. In this case, you need to define | ||
the sampler class somewhere within the existing package and then add an entry | ||
point. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
[build-system] | ||
requires = ["setuptools>=45", "setuptools-scm[toml]>=6.2"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[project] | ||
name = "demo_sampler_bilby" | ||
authors = [ | ||
{name = "your name", email = "[email protected]"}, | ||
] | ||
description = "" | ||
readme = "README.md" | ||
requires-python = ">=3.9" | ||
license = {text = "MIT"} | ||
classifiers = [ | ||
"Programming Language :: Python :: 3", | ||
] | ||
dependencies = [ | ||
"bilby", | ||
"numpy", | ||
] | ||
|
||
dynamic = ["version"] | ||
|
||
[project.optional-dependencies] | ||
test = [ | ||
"pytest", | ||
] | ||
|
||
[tool.setuptools_scm] | ||
|
||
[tool.black] | ||
line-length = 79 | ||
|
||
[project.entry-points."bilby.samplers"] | ||
demo_sampler = "demo_sampler_bilby.plugin:DemoSampler" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""An example of how to implement a sampler plugin in for bilby. | ||
This package provides the 'demo_sampler' sampler. | ||
""" | ||
from importlib.metadata import PackageNotFoundError, version | ||
|
||
try: | ||
__version__ = version(__name__) | ||
except PackageNotFoundError: | ||
# package is not installed | ||
__version__ = "unknown" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""Example plugin for using a sampler in bilby. | ||
Here we demonstrate the how to implement the class. | ||
""" | ||
import bilby | ||
from bilby.core.sampler.base_sampler import MCMCSampler, NestedSampler | ||
import numpy as np | ||
|
||
|
||
class DemoSampler(NestedSampler): | ||
"""Bilby wrapper for your sampler. | ||
This class should inherit from :code:`MCMCSampler` or :code:`NestedSampler` | ||
""" | ||
|
||
@property | ||
def external_sampler_name(self) -> str: | ||
"""The name of package that provides the sampler.""" | ||
# In this template we do not require any external codes, so we just | ||
# use bilby. You should change this. | ||
return "bilby" | ||
|
||
@property | ||
def default_kwargs(self) -> dict: | ||
"""Dictionary of default keyword arguments. | ||
Any arguments not included here will be removed before calling the | ||
sampler. | ||
""" | ||
return dict( | ||
ninitial=100, | ||
) | ||
|
||
def run_sampler(self) -> dict: | ||
"""Run the sampler. | ||
This method should run the sampler and update the result object. | ||
It should also return the result object. | ||
""" | ||
|
||
# The code below shows how you can call different methods. | ||
# Replace this code with calls to your sampler | ||
|
||
# Keyword arguments are stored in self.kwargs | ||
prior_samples = np.array( | ||
list(self.priors.sample(self.kwargs["ninitial"]).values()), | ||
).T | ||
# We can evaluate the log-prior and log-likelihood | ||
logl = np.empty(len(prior_samples)) | ||
logp = np.empty(len(prior_samples)) | ||
for i, sample in enumerate(prior_samples): | ||
logl[i] = self.log_likelihood(sample) | ||
logp[i] = self.log_prior(sample) | ||
|
||
# Generate posterior samples | ||
logw = logl.copy() - logl.max() | ||
keep = logw > np.log(np.random.rand(len(logw))) | ||
posterior_samples = prior_samples[keep] | ||
|
||
# The result object is created automatically | ||
# So we just have to populate the different methods | ||
# Add the posterior samples to the result object | ||
# This should be a numpy array of shape (# samples x # parameters) | ||
self.result.samples = posterior_samples | ||
# We can also store the log-likelihood and log-prior values for each | ||
# posterior sample | ||
self.result.log_likelihood_evaluations = logl[keep] | ||
self.result.log_prior_evaluations = logp[keep] | ||
# If it is a nested sampler, we can add the nested samples | ||
self.result.nested_samples = prior_samples | ||
# We can also add the log-evidence and the error | ||
# These can be NaNs for samplers that no not estimate the evidence | ||
self.result.log_evidence = np.mean(logl) | ||
self.result.log_evidence_err = np.std(logl) | ||
|
||
# Must return the result object | ||
return self.result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import bilby | ||
import numpy as np | ||
import pytest | ||
|
||
|
||
def model(x, m, c): | ||
return m * x + c | ||
|
||
|
||
@pytest.fixture() | ||
def bilby_likelihood(): | ||
bilby.core.utils.random.seed(42) | ||
rng = bilby.core.utils.random.rng | ||
x = np.linspace(0, 1, 11) | ||
injection_parameters = dict(m=0.5, c=0.2) | ||
sigma = 0.1 | ||
y = model(x, **injection_parameters) + rng.normal(0.0, sigma, len(x)) | ||
likelihood = bilby.likelihood.GaussianLikelihood(x, y, model, sigma) | ||
return likelihood | ||
|
||
|
||
@pytest.fixture() | ||
def bilby_priors(): | ||
priors = bilby.core.prior.PriorDict() | ||
priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic") | ||
priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective") | ||
return priors | ||
|
||
|
||
@pytest.fixture() | ||
def sampler_kwargs(): | ||
# Any keyword arguments that need to be set of you want to test | ||
return dict( | ||
ninitial=100, | ||
) | ||
|
||
|
||
def test_run_sampler(bilby_likelihood, bilby_priors, tmp_path, sampler_kwargs): | ||
outdir = tmp_path / "test_run_sampler" | ||
|
||
bilby.run_sampler( | ||
likelihood=bilby_likelihood, | ||
priors=bilby_priors, | ||
sampler="demo_sampler", # This should match the name of the sampler | ||
outdir=outdir, | ||
) |