From a8542c0c0cb93a3c7a1764868eb65f308dc474fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Wed, 29 May 2024 11:36:53 +0200 Subject: [PATCH] Rename the fine submodule to bmm in src/ and tests/ --- pyproject.toml | 5 +++ src/bmi/benchmark/tasks/mixtures.py | 50 ++++++++++++++--------------- src/bmi/estimators/external/gmm.py | 30 +++++++++++------ src/bmi/samplers/__init__.py | 4 +-- src/bmi/samplers/_tfp/_core.py | 7 ++-- src/bmi/samplers/_tfp/_wrapper.py | 6 ++-- tests/samplers/tfp/test_product.py | 6 ++-- tests/samplers/tfp/test_wrapper.py | 6 ++-- 8 files changed, 65 insertions(+), 49 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5758a00f..ed4b0d25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,11 @@ scipy = "^1.10.1" tqdm = "^4.64.1" tensorflow-probability = {extras = ["jax"], version = "^0.20.1"} +[tool.poetry.group.bayes] +optional = true + +[tool.poetry.group.bayes.dependencies] +numpyro = "^0.14.0" [tool.poetry.group.dev] optional = true diff --git a/src/bmi/benchmark/tasks/mixtures.py b/src/bmi/benchmark/tasks/mixtures.py index 4561229b..f1e29f21 100644 --- a/src/bmi/benchmark/tasks/mixtures.py +++ b/src/bmi/benchmark/tasks/mixtures.py @@ -4,7 +4,7 @@ import bmi.samplers as samplers import bmi.transforms as transforms from bmi.benchmark.task import Task -from bmi.samplers import fine +from bmi.samplers import bmm _MC_MI_ESTIMATE_SAMPLE = 100_000 @@ -15,10 +15,10 @@ def task_x( ) -> Task: """The X distribution.""" - dist = fine.mixture( + dist = bmm.mixture( proportions=jnp.array([0.5, 0.5]), components=[ - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( covariance=samplers.canonical_correlation([x * gaussian_correlation]), mean=jnp.zeros(2), dim_x=1, @@ -27,7 +27,7 @@ def task_x( for x in [-1, 1] ], ) - sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample) + sampler = bmm.FineSampler(dist, mi_estimate_sample=mi_estimate_sample) return Task( sampler=sampler, @@ -47,36 +47,36 @@ def task_ai( corr = 0.95 var_x = 0.04 - dist = fine.mixture( + dist = bmm.mixture( proportions=jnp.full(6, fill_value=1 / 6), components=[ # I components - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( dim_x=1, dim_y=1, mean=jnp.array([1.0, 0.0]), covariance=np.diag([0.01, 0.2]), ), - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( dim_x=1, dim_y=1, mean=jnp.array([1.0, 1]), covariance=np.diag([0.05, 0.001]), ), - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( dim_x=1, dim_y=1, mean=jnp.array([1.0, -1]), covariance=np.diag([0.05, 0.001]), ), # A components - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( dim_x=1, dim_y=1, mean=jnp.array([-0.8, -0.2]), covariance=np.diag([0.03, 0.001]), ), - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( dim_x=1, dim_y=1, mean=jnp.array([-1.2, 0.0]), @@ -84,7 +84,7 @@ def task_ai( [[var_x, jnp.sqrt(var_x * 0.2) * corr], [jnp.sqrt(var_x * 0.2) * corr, 0.2]] ), ), - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( dim_x=1, dim_y=1, mean=jnp.array([-0.4, 0.0]), @@ -94,7 +94,7 @@ def task_ai( ), ], ) - sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample) + sampler = bmm.FineSampler(dist, mi_estimate_sample=mi_estimate_sample) return Task( sampler=sampler, @@ -110,10 +110,10 @@ def task_galaxy( ) -> Task: """The Galaxy distribution.""" - balls_mixt = fine.mixture( + balls_mixt = bmm.mixture( proportions=jnp.array([0.5, 0.5]), components=[ - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( covariance=samplers.canonical_correlation([0.0], additional_y=1), mean=jnp.array([x, x, x]) * distance / 2, dim_x=2, @@ -123,7 +123,7 @@ def task_galaxy( ], ) - base_sampler = fine.FineSampler(balls_mixt, mi_estimate_sample=mi_estimate_sample) + base_sampler = bmm.FineSampler(balls_mixt, mi_estimate_sample=mi_estimate_sample) a = jnp.array([[0, -1], [1, 0]]) spiral = transforms.Spiral(a, speed=speed) @@ -150,10 +150,10 @@ def task_waves( assert n_components > 0 - base_dist = fine.mixture( + base_dist = bmm.mixture( proportions=jnp.ones(n_components) / n_components, components=[ - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( covariance=jnp.diag(jnp.array([0.1, 1.0, 0.1])), mean=jnp.array([x, 0, x % 4]) * 1.5, dim_x=2, @@ -162,7 +162,7 @@ def task_waves( for x in range(n_components) ], ) - base_sampler = fine.FineSampler(base_dist, mi_estimate_sample=mi_estimate_sample) + base_sampler = bmm.FineSampler(base_dist, mi_estimate_sample=mi_estimate_sample) aux_sampler = samplers.TransformedSampler( base_sampler, transform_x=lambda x: x @@ -193,10 +193,10 @@ def task_concentric_multinormal( assert n_components > 0 - dist = fine.mixture( + dist = bmm.mixture( proportions=jnp.ones(n_components) / n_components, components=[ - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( covariance=jnp.diag(jnp.array(dim_x * [i**2] + [0.0001])), mean=jnp.array(dim_x * [0.0] + [1.0 * i]), dim_x=dim_x, @@ -205,7 +205,7 @@ def task_concentric_multinormal( for i in range(1, 1 + n_components) ], ) - sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample) + sampler = bmm.FineSampler(dist, mi_estimate_sample=mi_estimate_sample) return Task( sampler=sampler, @@ -238,23 +238,23 @@ def task_multinormal_sparse_w_inliers( eta_x=strength, ) - signal_dist = fine.MultivariateNormalDistribution( + signal_dist = bmm.MultivariateNormalDistribution( dim_x=dim_x, dim_y=dim_y, covariance=params.correlation, ) - noise_dist = fine.ProductDistribution( + noise_dist = bmm.ProductDistribution( dist_x=signal_dist.dist_x, dist_y=signal_dist.dist_y, ) - dist = fine.mixture( + dist = bmm.mixture( proportions=jnp.array([1 - inlier_fraction, inlier_fraction]), components=[signal_dist, noise_dist], ) - sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample) + sampler = bmm.FineSampler(dist, mi_estimate_sample=mi_estimate_sample) task_id = f"mult-sparse-w-inliers-{dim_x}-{dim_y}-{n_interacting}-{strength}-{inlier_fraction}" return Task( diff --git a/src/bmi/estimators/external/gmm.py b/src/bmi/estimators/external/gmm.py index e3811ced..3760a7f9 100644 --- a/src/bmi/estimators/external/gmm.py +++ b/src/bmi/estimators/external/gmm.py @@ -1,3 +1,13 @@ +"""A Gaussian mixture model estimator, allowing for model-based +Bayesian estimator of mutual information. +The full description can be found [here](https://arxiv.org/abs/2310.10240). + +Note that to use this estimator you need to install external dependencies: +```bash +$ pip install benchmark-mi[bayes] +``` +""" + try: import numpyro # type: ignore import numpyro.distributions as dist # type: ignore @@ -12,7 +22,7 @@ from numpy.typing import ArrayLike from bmi.interface import BaseModel, IMutualInformationPointEstimator -from bmi.samplers import fine +from bmi.samplers import bmm from bmi.utils import ProductSpace @@ -74,14 +84,14 @@ def model( ) -def sample_into_fine_distribution( +def sample_into_bmm_distribution( means: jnp.ndarray, covariances: jnp.ndarray, proportions: jnp.ndarray, dim_x: int, dim_y: int, -) -> fine.JointDistribution: - """Builds a fine distribution from a Gaussian mixture model parameters.""" +) -> bmm.JointDistribution: + """Builds a bmm distribution from a Gaussian mixture model parameters.""" # Check if the dimensions are right n_components = proportions.shape[0] n_dims = dim_x + dim_y @@ -90,7 +100,7 @@ def sample_into_fine_distribution( # Build components components = [ - fine.MultivariateNormalDistribution( + bmm.MultivariateNormalDistribution( dim_x=dim_x, dim_y=dim_y, mean=mean, @@ -100,7 +110,7 @@ def sample_into_fine_distribution( ] # Build a mixture model - return fine.mixture(proportions=proportions, components=components) + return bmm.mixture(proportions=proportions, components=components) class GMMEstimatorParams(BaseModel): @@ -185,12 +195,12 @@ def run_mcmc(self, x: ArrayLike, y: ArrayLike): self._dim_x = space.dim_x self._dim_y = space.dim_y - def get_fine_distribution(self, idx: int) -> fine.JointDistribution: + def get_bmm_distribution(self, idx: int) -> bmm.JointDistribution: if self._mcmc is None: raise ValueError("You need to run MCMC first. See the `run_mcmc` method.") samples = self._mcmc.get_samples() - return sample_into_fine_distribution( + return sample_into_bmm_distribution( means=samples["mu"][idx], covariances=samples["cov"][idx], proportions=samples["pi"][idx], @@ -204,8 +214,8 @@ def get_sample_mi(self, idx: int, mc_samples: Optional[int] = None, key=None) -> if key is None: self.key, key = jax.random.split(self.key) - distribution = self.get_fine_distribution(idx) - mi, _ = fine.monte_carlo_mi_estimate(key=key, dist=distribution, n=mc_samples) + distribution = self.get_bmm_distribution(idx) + mi, _ = bmm.monte_carlo_mi_estimate(key=key, dist=distribution, n=mc_samples) return mi def get_posterior_mi( diff --git a/src/bmi/samplers/__init__.py b/src/bmi/samplers/__init__.py index 85bcb497..455af6ca 100644 --- a/src/bmi/samplers/__init__.py +++ b/src/bmi/samplers/__init__.py @@ -13,7 +13,7 @@ ) # isort: on -import bmi.samplers._tfp as fine +import bmi.samplers._tfp as bmm from bmi.samplers._independent_coordinates import IndependentConcatenationSampler from bmi.samplers._split_student_t import SplitStudentT from bmi.samplers._splitmultinormal import BivariateNormalSampler, SplitMultinormal @@ -33,7 +33,7 @@ "AdditiveUniformSampler", "BaseSampler", "canonical_correlation", - "fine", + "bmm", "parametrised_correlation_matrix", "BivariateNormalSampler", "SplitMultinormal", diff --git a/src/bmi/samplers/_tfp/_core.py b/src/bmi/samplers/_tfp/_core.py index 262b61e7..ed075af5 100644 --- a/src/bmi/samplers/_tfp/_core.py +++ b/src/bmi/samplers/_tfp/_core.py @@ -11,9 +11,10 @@ @dataclasses.dataclass class JointDistribution: - """The main object of this package. - Represents a joint distribution $P_{XY}$ together - with the marginal distributions $P_X$ and $P_Y$. + """The main object of this package, representing + a Bend and Mix Model (BMM), i.e., a joint distribution + $P_{XY}$ together with the marginal distributions + $P_X$ and $P_Y$. Attributes: dist: $P_{XY}$ diff --git a/src/bmi/samplers/_tfp/_wrapper.py b/src/bmi/samplers/_tfp/_wrapper.py index 4cc8ee79..4ea406a1 100644 --- a/src/bmi/samplers/_tfp/_wrapper.py +++ b/src/bmi/samplers/_tfp/_wrapper.py @@ -9,7 +9,7 @@ class FineSampler(BaseSampler): - """Wraps a given fine distribution into a sampler.""" + """Wraps a given Bend and Mix Model (BMM) into a sampler.""" def __init__( self, @@ -21,8 +21,8 @@ def __init__( """ Args: - dist: fine distribution to be wrapped - mi: mutual information of the fine distribution, if already calculated. + dist: distribution represented by a BMM to be wrapped + mi: mutual information of the distribution, if already calculated. If not provided, it will be estimated via Monte Carlo sampling. mi_estimate_seed: seed for the Monte Carlo sampling mi_estimate_sample: number of samples for the Monte Carlo sampling diff --git a/tests/samplers/tfp/test_product.py b/tests/samplers/tfp/test_product.py index 1f025da1..06ce1ef0 100644 --- a/tests/samplers/tfp/test_product.py +++ b/tests/samplers/tfp/test_product.py @@ -3,20 +3,20 @@ import pytest import bmi.samplers -from bmi.samplers import fine +from bmi.samplers import bmm def test_product_distribution(dim_x: int = 2, dim_y: int = 3, n_points: int = 10) -> None: assert dim_y >= dim_x, "We construct canonical correlation matrix, so we want this constraint." - dist_dependent = fine.MultivariateNormalDistribution( + dist_dependent = bmm.MultivariateNormalDistribution( dim_x=dim_x, dim_y=dim_y, covariance=bmi.samplers.canonical_correlation( jnp.full((dim_x,), fill_value=0.5), additional_y=dim_y - dim_x ), ) - dist_independent = fine.ProductDistribution( + dist_independent = bmm.ProductDistribution( dist_x=dist_dependent.dist_x, dist_y=dist_dependent.dist_y ) diff --git a/tests/samplers/tfp/test_wrapper.py b/tests/samplers/tfp/test_wrapper.py index ff257c78..46ef0758 100644 --- a/tests/samplers/tfp/test_wrapper.py +++ b/tests/samplers/tfp/test_wrapper.py @@ -1,16 +1,16 @@ import jax.numpy as jnp import pytest -from bmi.samplers import fine +from bmi.samplers import bmm def test_can_create_sampler() -> None: - dist = fine.MultivariateNormalDistribution( + dist = bmm.MultivariateNormalDistribution( dim_x=1, dim_y=1, covariance=jnp.asarray([[1, 0.5], [0.5, 1]]) ) mi = -0.5 * jnp.log(1 - 0.5**2) - sampler = fine.FineSampler(dist=dist, mi_estimate_seed=0, mi_estimate_sample=1_000) + sampler = bmm.FineSampler(dist=dist, mi_estimate_seed=0, mi_estimate_sample=1_000) x_sample, y_sample = sampler.sample(n_points=10, rng=0) assert x_sample.shape == (10, 1)