diff --git a/README.md b/README.md
index dc79252b..5c9c83c9 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,7 @@ print(f"Estimate by KSG: {ksg.estimate(X, Y):.2f}")
 ### Evaluating a new estimator
 
 The above code snippet may be convenient for estimating mutual information on a given data set or for the development of a new mutual information estimator.
-However, for extensive benchmarking it may be more convenient to use one benchmark suites available in the `workflows/benchmark/` subdirectory.
+However, for extensive benchmarking it may be more convenient to use one of the benchmark suites available in the `workflows/benchmark/` subdirectory.
 
 For example, you can install [Snakemake](https://snakemake.readthedocs.io/en/stable/) and run a small benchmark suite on several estimators using:
 
@@ -71,6 +71,8 @@ We defined several benchmark suites with shared structure.
 
 ## List of implemented estimators
 
+(Your estimator can be here too! Please reach out to us if you would like to contribute.)
+
 - The neighborhood-based KSG estimator proposed in [Estimating Mutual Information](https://arxiv.org/abs/cond-mat/0305641) by Kraskov et al. (2003).
 - Donsker-Varadhan and MINE estimators proposed in [MINE: Mutual Information Neural Estimation](https://arxiv.org/abs/1801.04062) by Belghazi et al. (2018).
 - InfoNCE estimator proposed in [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748) by Oord et al. (2018).
diff --git a/docs/benchmarking-new-estimator.md b/docs/benchmarking-new-estimator.md
index 652923cc..298cb784 100644
--- a/docs/benchmarking-new-estimator.md
+++ b/docs/benchmarking-new-estimator.md
@@ -17,6 +17,50 @@ The `config.py` file is plain Python!
 
 We defined several benchmark suites with the shared structure.
 
+## Adding a new Python estimator
+
+Every added estimator has to implement the [`IMutualInformationPointEstimator`](https://github.com/cbg-ethz/bmi/blob/main/src/bmi/interface.py) interface.
+
+Let's take a look at the simplest possible estimator, which generally shouldn't be used in practice. It uses only the first coordinates of the two variables and estimates mutual information as $\hat I(X; Y) = -0.5 \log\left( 1-\mathrm{Corr}(X_1, Y_1)^2 \right)$.
+
+```{python}
+import numpy as np
+
+from bmi.interface import BaseModel, IMutualInformationPointEstimator
+
+class EmptyParams(BaseModel):
+    """This estimator has no hyperparameters."""
+    pass
+
+class TooSimpleEstimator(IMutualInformationPointEstimator):
+    def __init__(self) -> None:
+        """All estimator hyperparameters should be set at this stage.
+        This estimator does not have any hyperparameters, though.
+        """
+        pass
+
+    def estimate(self, x, y) -> float:
+        """Estimates mutual information.
+
+        Args:
+            x: samples from the X variable, shape (n_points, dim_x)
+            y: samples from the Y variable, shape (n_points, dim_y)
+        """
+        x, y = np.asarray(x), np.asarray(y)
+        x1 = x[:, 0]
+        y1 = y[:, 0]
+
+        rho = np.corrcoef(x1, y1)[0, 1]
+        return -0.5 * np.log(1 - rho**2)
+
+    def parameters(self) -> BaseModel:
+        """Returns the hyperparameters of the estimator."""
+        return EmptyParams()
+```
+
+If your estimator is a function, you can also wrap it into a class using the [`FunctionalEstimator`](https://github.com/cbg-ethz/bmi/blob/main/src/bmi/estimators/function_wrapper.py) wrapper.
+
+Once such a class is available, it can simply be included in the configuration dictionary.
 
 ## Adding a non-Python estimator