Rename the fine submodule to bmm in src/ and tests/
pawel-czyz committed May 29, 2024
1 parent 7fc5280 commit a8542c0
Showing 8 changed files with 65 additions and 49 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -22,6 +22,11 @@ scipy = "^1.10.1"
tqdm = "^4.64.1"
tensorflow-probability = {extras = ["jax"], version = "^0.20.1"}

[tool.poetry.group.bayes]
optional = true

[tool.poetry.group.bayes.dependencies]
numpyro = "^0.14.0"

[tool.poetry.group.dev]
optional = true
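With the new optional `bayes` group above, the Bayesian extras (numpyro) can be pulled in on demand; a minimal sketch, assuming a standard Poetry (>= 1.2) workflow for this repository. (The `gmm.py` module below points pip users to `pip install benchmark-mi[bayes]` instead.)

```bash
# Install the project together with the optional Bayesian dependencies declared above.
poetry install --with bayes
```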
50 changes: 25 additions & 25 deletions src/bmi/benchmark/tasks/mixtures.py
@@ -4,7 +4,7 @@
import bmi.samplers as samplers
import bmi.transforms as transforms
from bmi.benchmark.task import Task
from bmi.samplers import fine
from bmi.samplers import bmm

_MC_MI_ESTIMATE_SAMPLE = 100_000

@@ -15,10 +15,10 @@ def task_x(
) -> Task:
"""The X distribution."""

dist = fine.mixture(
dist = bmm.mixture(
proportions=jnp.array([0.5, 0.5]),
components=[
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
covariance=samplers.canonical_correlation([x * gaussian_correlation]),
mean=jnp.zeros(2),
dim_x=1,
@@ -27,7 +27,7 @@
for x in [-1, 1]
],
)
sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
sampler = bmm.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)

return Task(
sampler=sampler,
@@ -47,44 +47,44 @@ def task_ai(
corr = 0.95
var_x = 0.04

dist = fine.mixture(
dist = bmm.mixture(
proportions=jnp.full(6, fill_value=1 / 6),
components=[
# I components
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([1.0, 0.0]),
covariance=np.diag([0.01, 0.2]),
),
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([1.0, 1]),
covariance=np.diag([0.05, 0.001]),
),
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([1.0, -1]),
covariance=np.diag([0.05, 0.001]),
),
# A components
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([-0.8, -0.2]),
covariance=np.diag([0.03, 0.001]),
),
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([-1.2, 0.0]),
covariance=jnp.array(
[[var_x, jnp.sqrt(var_x * 0.2) * corr], [jnp.sqrt(var_x * 0.2) * corr, 0.2]]
),
),
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([-0.4, 0.0]),
@@ -94,7 +94,7 @@ def task_ai(
),
],
)
sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
sampler = bmm.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)

return Task(
sampler=sampler,
@@ -110,10 +110,10 @@ def task_galaxy(
) -> Task:
"""The Galaxy distribution."""

balls_mixt = fine.mixture(
balls_mixt = bmm.mixture(
proportions=jnp.array([0.5, 0.5]),
components=[
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
covariance=samplers.canonical_correlation([0.0], additional_y=1),
mean=jnp.array([x, x, x]) * distance / 2,
dim_x=2,
@@ -123,7 +123,7 @@
],
)

base_sampler = fine.FineSampler(balls_mixt, mi_estimate_sample=mi_estimate_sample)
base_sampler = bmm.FineSampler(balls_mixt, mi_estimate_sample=mi_estimate_sample)
a = jnp.array([[0, -1], [1, 0]])
spiral = transforms.Spiral(a, speed=speed)

@@ -150,10 +150,10 @@ def task_waves(

assert n_components > 0

base_dist = fine.mixture(
base_dist = bmm.mixture(
proportions=jnp.ones(n_components) / n_components,
components=[
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
covariance=jnp.diag(jnp.array([0.1, 1.0, 0.1])),
mean=jnp.array([x, 0, x % 4]) * 1.5,
dim_x=2,
@@ -162,7 +162,7 @@
for x in range(n_components)
],
)
base_sampler = fine.FineSampler(base_dist, mi_estimate_sample=mi_estimate_sample)
base_sampler = bmm.FineSampler(base_dist, mi_estimate_sample=mi_estimate_sample)
aux_sampler = samplers.TransformedSampler(
base_sampler,
transform_x=lambda x: x
@@ -193,10 +193,10 @@ def task_concentric_multinormal(

assert n_components > 0

dist = fine.mixture(
dist = bmm.mixture(
proportions=jnp.ones(n_components) / n_components,
components=[
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
covariance=jnp.diag(jnp.array(dim_x * [i**2] + [0.0001])),
mean=jnp.array(dim_x * [0.0] + [1.0 * i]),
dim_x=dim_x,
@@ -205,7 +205,7 @@
for i in range(1, 1 + n_components)
],
)
sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
sampler = bmm.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)

return Task(
sampler=sampler,
@@ -238,23 +238,23 @@ def task_multinormal_sparse_w_inliers(
eta_x=strength,
)

signal_dist = fine.MultivariateNormalDistribution(
signal_dist = bmm.MultivariateNormalDistribution(
dim_x=dim_x,
dim_y=dim_y,
covariance=params.correlation,
)

noise_dist = fine.ProductDistribution(
noise_dist = bmm.ProductDistribution(
dist_x=signal_dist.dist_x,
dist_y=signal_dist.dist_y,
)

dist = fine.mixture(
dist = bmm.mixture(
proportions=jnp.array([1 - inlier_fraction, inlier_fraction]),
components=[signal_dist, noise_dist],
)

sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
sampler = bmm.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)

task_id = f"mult-sparse-w-inliers-{dim_x}-{dim_y}-{n_interacting}-{strength}-{inlier_fraction}"
return Task(
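For reference, a minimal sketch of the renamed pattern these tasks share, mirroring `task_x` above (the correlation value is illustrative and `dim_y=1` is inferred from the two-dimensional mean; not an exact reproduction of the task defaults):

```python
import jax.numpy as jnp

import bmi.samplers as samplers
from bmi.samplers import bmm

# Two-component Gaussian mixture forming the "X" shape, as in task_x above.
dist = bmm.mixture(
    proportions=jnp.array([0.5, 0.5]),
    components=[
        bmm.MultivariateNormalDistribution(
            covariance=samplers.canonical_correlation([x * 0.9]),  # 0.9 is an illustrative value
            mean=jnp.zeros(2),
            dim_x=1,
            dim_y=1,
        )
        for x in [-1, 1]
    ],
)

# Wrap the distribution into a sampler; mutual information is estimated by Monte Carlo.
sampler = bmm.FineSampler(dist, mi_estimate_sample=10_000)
x_sample, y_sample = sampler.sample(n_points=100, rng=0)
```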
30 changes: 20 additions & 10 deletions src/bmi/estimators/external/gmm.py
@@ -1,3 +1,13 @@
"""A Gaussian mixture model estimator, allowing for model-based
Bayesian estimator of mutual information.
The full description can be found [here](https://arxiv.org/abs/2310.10240).
Note that to use this estimator you need to install external dependencies:
```bash
$ pip install benchmark-mi[bayes]
```
"""

try:
import numpyro # type: ignore
import numpyro.distributions as dist # type: ignore
@@ -12,7 +12,7 @@
from numpy.typing import ArrayLike

from bmi.interface import BaseModel, IMutualInformationPointEstimator
from bmi.samplers import fine
from bmi.samplers import bmm
from bmi.utils import ProductSpace


@@ -74,14 +84,14 @@ def model(
)


def sample_into_fine_distribution(
def sample_into_bmm_distribution(
means: jnp.ndarray,
covariances: jnp.ndarray,
proportions: jnp.ndarray,
dim_x: int,
dim_y: int,
) -> fine.JointDistribution:
"""Builds a fine distribution from a Gaussian mixture model parameters."""
) -> bmm.JointDistribution:
"""Builds a bmm distribution from a Gaussian mixture model parameters."""
# Check if the dimensions are right
n_components = proportions.shape[0]
n_dims = dim_x + dim_y
@@ -90,7 +100,7 @@ def sample_into_fine_distribution(

# Build components
components = [
fine.MultivariateNormalDistribution(
bmm.MultivariateNormalDistribution(
dim_x=dim_x,
dim_y=dim_y,
mean=mean,
Expand All @@ -100,7 +110,7 @@ def sample_into_fine_distribution(
]

# Build a mixture model
return fine.mixture(proportions=proportions, components=components)
return bmm.mixture(proportions=proportions, components=components)


class GMMEstimatorParams(BaseModel):
@@ -185,12 +195,12 @@ def run_mcmc(self, x: ArrayLike, y: ArrayLike):
self._dim_x = space.dim_x
self._dim_y = space.dim_y

def get_fine_distribution(self, idx: int) -> fine.JointDistribution:
def get_bmm_distribution(self, idx: int) -> bmm.JointDistribution:
if self._mcmc is None:
raise ValueError("You need to run MCMC first. See the `run_mcmc` method.")

samples = self._mcmc.get_samples()
return sample_into_fine_distribution(
return sample_into_bmm_distribution(
means=samples["mu"][idx],
covariances=samples["cov"][idx],
proportions=samples["pi"][idx],
@@ -204,8 +214,8 @@ def get_sample_mi(self, idx: int, mc_samples: Optional[int] = None, key=None) ->
if key is None:
self.key, key = jax.random.split(self.key)

distribution = self.get_fine_distribution(idx)
mi, _ = fine.monte_carlo_mi_estimate(key=key, dist=distribution, n=mc_samples)
distribution = self.get_bmm_distribution(idx)
mi, _ = bmm.monte_carlo_mi_estimate(key=key, dist=distribution, n=mc_samples)
return mi

def get_posterior_mi(
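A minimal sketch of the renamed helper in isolation (hypothetical GMM parameters; requires the optional `bayes` dependencies because this module imports `numpyro` at the top):

```python
import jax
import jax.numpy as jnp

from bmi.estimators.external.gmm import sample_into_bmm_distribution
from bmi.samplers import bmm

# Hypothetical parameters of a two-component GMM over a (1 + 1)-dimensional space.
means = jnp.array([[0.0, 0.0], [1.5, 1.5]])
covariances = jnp.stack([jnp.eye(2), jnp.eye(2)])
proportions = jnp.array([0.5, 0.5])

dist = sample_into_bmm_distribution(
    means=means,
    covariances=covariances,
    proportions=proportions,
    dim_x=1,
    dim_y=1,
)

# Monte Carlo MI estimate, mirroring the call in get_sample_mi above.
mi, _ = bmm.monte_carlo_mi_estimate(key=jax.random.PRNGKey(0), dist=dist, n=10_000)
```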
4 changes: 2 additions & 2 deletions src/bmi/samplers/__init__.py
@@ -13,7 +13,7 @@
)

# isort: on
import bmi.samplers._tfp as fine
import bmi.samplers._tfp as bmm
from bmi.samplers._independent_coordinates import IndependentConcatenationSampler
from bmi.samplers._split_student_t import SplitStudentT
from bmi.samplers._splitmultinormal import BivariateNormalSampler, SplitMultinormal
@@ -33,7 +33,7 @@
"AdditiveUniformSampler",
"BaseSampler",
"canonical_correlation",
"fine",
"bmm",
"parametrised_correlation_matrix",
"BivariateNormalSampler",
"SplitMultinormal",
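Since only the public alias changes (the implementation still lives in `bmi.samplers._tfp`), downstream code migrates with a one-line import swap; a minimal sketch:

```python
import jax.numpy as jnp

# Before this commit:
#   from bmi.samplers import fine
# After this commit:
from bmi.samplers import bmm

dist = bmm.MultivariateNormalDistribution(
    dim_x=1, dim_y=1, covariance=jnp.asarray([[1.0, 0.5], [0.5, 1.0]])
)
```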
7 changes: 4 additions & 3 deletions src/bmi/samplers/_tfp/_core.py
@@ -11,9 +11,10 @@

@dataclasses.dataclass
class JointDistribution:
"""The main object of this package.
Represents a joint distribution $P_{XY}$ together
with the marginal distributions $P_X$ and $P_Y$.
"""The main object of this package, representing
a Bend and Mix Model (BMM), i.e., a joint distribution
$P_{XY}$ together with the marginal distributions
$P_X$ and $P_Y$.
Attributes:
dist: $P_{XY}$
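As the updated docstring says, a `JointDistribution` carries $P_{XY}$ together with the marginals $P_X$ and $P_Y$; a minimal sketch of how the marginals are reused elsewhere in this commit (cf. `task_multinormal_sparse_w_inliers` and `tests/samplers/tfp/test_product.py`):

```python
import jax.numpy as jnp

from bmi.samplers import bmm

joint = bmm.MultivariateNormalDistribution(
    dim_x=1, dim_y=1, covariance=jnp.asarray([[1.0, 0.5], [0.5, 1.0]])
)

# Product of the marginals: same P_X and P_Y, but zero mutual information.
independent = bmm.ProductDistribution(dist_x=joint.dist_x, dist_y=joint.dist_y)
```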
6 changes: 3 additions & 3 deletions src/bmi/samplers/_tfp/_wrapper.py
@@ -9,7 +9,7 @@


class FineSampler(BaseSampler):
"""Wraps a given fine distribution into a sampler."""
"""Wraps a given Bend and Mix Model (BMM) into a sampler."""

def __init__(
self,
@@ -21,8 +21,8 @@ def __init__(
"""
Args:
dist: fine distribution to be wrapped
mi: mutual information of the fine distribution, if already calculated.
dist: distribution represented by a BMM to be wrapped
mi: mutual information of the distribution, if already calculated.
If not provided, it will be estimated via Monte Carlo sampling.
mi_estimate_seed: seed for the Monte Carlo sampling
mi_estimate_sample: number of samples for the Monte Carlo sampling
6 changes: 3 additions & 3 deletions tests/samplers/tfp/test_product.py
@@ -3,20 +3,20 @@
import pytest

import bmi.samplers
from bmi.samplers import fine
from bmi.samplers import bmm


def test_product_distribution(dim_x: int = 2, dim_y: int = 3, n_points: int = 10) -> None:
assert dim_y >= dim_x, "We construct canonical correlation matrix, so we want this constraint."

dist_dependent = fine.MultivariateNormalDistribution(
dist_dependent = bmm.MultivariateNormalDistribution(
dim_x=dim_x,
dim_y=dim_y,
covariance=bmi.samplers.canonical_correlation(
jnp.full((dim_x,), fill_value=0.5), additional_y=dim_y - dim_x
),
)
dist_independent = fine.ProductDistribution(
dist_independent = bmm.ProductDistribution(
dist_x=dist_dependent.dist_x, dist_y=dist_dependent.dist_y
)

6 changes: 3 additions & 3 deletions tests/samplers/tfp/test_wrapper.py
@@ -1,16 +1,16 @@
import jax.numpy as jnp
import pytest

from bmi.samplers import fine
from bmi.samplers import bmm


def test_can_create_sampler() -> None:
dist = fine.MultivariateNormalDistribution(
dist = bmm.MultivariateNormalDistribution(
dim_x=1, dim_y=1, covariance=jnp.asarray([[1, 0.5], [0.5, 1]])
)
mi = -0.5 * jnp.log(1 - 0.5**2)

sampler = fine.FineSampler(dist=dist, mi_estimate_seed=0, mi_estimate_sample=1_000)
sampler = bmm.FineSampler(dist=dist, mi_estimate_seed=0, mi_estimate_sample=1_000)

x_sample, y_sample = sampler.sample(n_points=10, rng=0)
assert x_sample.shape == (10, 1)
