Skip to content

Commit

Permalink
Merge pull request #2 from mj-will/add-initial-template
Browse files Browse the repository at this point in the history
Add initial template
  • Loading branch information
mj-will authored Jan 12, 2024
2 parents e3863da + 9acf5c9 commit 254bab6
Show file tree
Hide file tree
Showing 6 changed files with 277 additions and 0 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Unit tests

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
unittests:
name: Unit tests - Python ${{ matrix.python-version }}

strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[test]
- name: Test with pytest
run: |
python -m pytest
72 changes: 72 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,74 @@
# sampler-template

Template for sampler plugins in bilby

There are three main components to a sampler plugin for bilby:

- The sampler class
- The `pyproject.toml`
- The test suite


## The sampler class

This is the interface between the external sampling code and bilby, in this
template repo is it located in `src/demo_sampler_bilby/plugin.py`.

The class should inherit from one of:

- `bilby.core.sampler.Sampler`
- `bilby.core.sampler.MCMCSampler`
- `bilby.core.sampler.NestedSampler`

The sampler is run in the `run_sampler` method


## pyproject.toml

Various fields in the `pyproject.toml` need to be set, these are:

* `name`: this is the name Python package, we recommend using `<name of your sampler>_bilby`
* `author`: this should include any authors
* `dependencies`: any dependencies should be included here, this must include `bilby`
* `[project.entry-points."bilby.samplers"]` see below


### Adding the entry points

This section of the `pyproject.toml` makes the sampler 'visible' within bilby.

```
[project.entry-points."bilby.samplers"]
demo_sampler = "demo_sampler_bilby.plugin:DemoSampler"
```

The name of the sampler within bilby is determined based on the name used here
(e.g. `demo_sampler` in this case).

The string should points to the file and, after the colon, the sampler class.


## Tests

The plugin should include a test suite.

`tests/test_bilby_integration.py` includes a standard test that
all samplers should pass but other tests can be included in here. The file should
be updated name matches the name of the sampler provided by the plugin and any keyword
arguments should be set.


### Continuous Integration

The tests are run automatically via GitHub Actions. These are configured in
`.github/workflow/test.yaml`. You may wish to change the version of Python
being tested or included additional configuration, e.g. for more complex installation
processes.


## Plugin in an existing package

It is also possible to include the plugin in an existing package rather
than creating a separate plugin package. In this case, you need to define
the sampler class somewhere within the existing package and then add an entry
point.
35 changes: 35 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
[build-system]
requires = ["setuptools>=45", "setuptools-scm[toml]>=6.2"]
build-backend = "setuptools.build_meta"

[project]
name = "demo_sampler_bilby"
authors = [
{name = "your name", email = "[email protected]"},
]
description = ""
readme = "README.md"
requires-python = ">=3.9"
license = {text = "MIT"}
classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = [
"bilby",
"numpy",
]

dynamic = ["version"]

[project.optional-dependencies]
test = [
"pytest",
]

[tool.setuptools_scm]

[tool.black]
line-length = 79

[project.entry-points."bilby.samplers"]
demo_sampler = "demo_sampler_bilby.plugin:DemoSampler"
11 changes: 11 additions & 0 deletions src/demo_sampler_bilby/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""An example of how to implement a sampler plugin in for bilby.
This package provides the 'demo_sampler' sampler.
"""
from importlib.metadata import PackageNotFoundError, version

try:
__version__ = version(__name__)
except PackageNotFoundError:
# package is not installed
__version__ = "unknown"
77 changes: 77 additions & 0 deletions src/demo_sampler_bilby/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Example plugin for using a sampler in bilby.
Here we demonstrate the how to implement the class.
"""
import bilby
from bilby.core.sampler.base_sampler import MCMCSampler, NestedSampler
import numpy as np


class DemoSampler(NestedSampler):
"""Bilby wrapper for your sampler.
This class should inherit from :code:`MCMCSampler` or :code:`NestedSampler`
"""

@property
def external_sampler_name(self) -> str:
"""The name of package that provides the sampler."""
# In this template we do not require any external codes, so we just
# use bilby. You should change this.
return "bilby"

@property
def default_kwargs(self) -> dict:
"""Dictionary of default keyword arguments.
Any arguments not included here will be removed before calling the
sampler.
"""
return dict(
ninitial=100,
)

def run_sampler(self) -> dict:
"""Run the sampler.
This method should run the sampler and update the result object.
It should also return the result object.
"""

# The code below shows how you can call different methods.
# Replace this code with calls to your sampler

# Keyword arguments are stored in self.kwargs
prior_samples = np.array(
list(self.priors.sample(self.kwargs["ninitial"]).values()),
).T
# We can evaluate the log-prior and log-likelihood
logl = np.empty(len(prior_samples))
logp = np.empty(len(prior_samples))
for i, sample in enumerate(prior_samples):
logl[i] = self.log_likelihood(sample)
logp[i] = self.log_prior(sample)

# Generate posterior samples
logw = logl.copy() - logl.max()
keep = logw > np.log(np.random.rand(len(logw)))
posterior_samples = prior_samples[keep]

# The result object is created automatically
# So we just have to populate the different methods
# Add the posterior samples to the result object
# This should be a numpy array of shape (# samples x # parameters)
self.result.samples = posterior_samples
# We can also store the log-likelihood and log-prior values for each
# posterior sample
self.result.log_likelihood_evaluations = logl[keep]
self.result.log_prior_evaluations = logp[keep]
# If it is a nested sampler, we can add the nested samples
self.result.nested_samples = prior_samples
# We can also add the log-evidence and the error
# These can be NaNs for samplers that no not estimate the evidence
self.result.log_evidence = np.mean(logl)
self.result.log_evidence_err = np.std(logl)

# Must return the result object
return self.result
46 changes: 46 additions & 0 deletions tests/test_bilby_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import bilby
import numpy as np
import pytest


def model(x, m, c):
return m * x + c


@pytest.fixture()
def bilby_likelihood():
bilby.core.utils.random.seed(42)
rng = bilby.core.utils.random.rng
x = np.linspace(0, 1, 11)
injection_parameters = dict(m=0.5, c=0.2)
sigma = 0.1
y = model(x, **injection_parameters) + rng.normal(0.0, sigma, len(x))
likelihood = bilby.likelihood.GaussianLikelihood(x, y, model, sigma)
return likelihood


@pytest.fixture()
def bilby_priors():
priors = bilby.core.prior.PriorDict()
priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
return priors


@pytest.fixture()
def sampler_kwargs():
# Any keyword arguments that need to be set of you want to test
return dict(
ninitial=100,
)


def test_run_sampler(bilby_likelihood, bilby_priors, tmp_path, sampler_kwargs):
outdir = tmp_path / "test_run_sampler"

bilby.run_sampler(
likelihood=bilby_likelihood,
priors=bilby_priors,
sampler="demo_sampler", # This should match the name of the sampler
outdir=outdir,
)

0 comments on commit 254bab6

Please sign in to comment.