diff --git a/src/bmi/samplers/__init__.py b/src/bmi/samplers/__init__.py
index 36e0bf64..2e92e703 100644
--- a/src/bmi/samplers/__init__.py
+++ b/src/bmi/samplers/__init__.py
@@ -12,6 +12,7 @@
 )
 
 # isort: on
+import bmi.samplers._tfp as fine
 from bmi.samplers._split_student_t import SplitStudentT
 from bmi.samplers._splitmultinormal import BivariateNormalSampler, SplitMultinormal
 from bmi.samplers._transformed import TransformedSampler
@@ -21,6 +22,7 @@
     "AdditiveUniformSampler",
     "BaseSampler",
     "canonical_correlation",
+    "fine",
     "parametrised_correlation_matrix",
     "BivariateNormalSampler",
     "SplitMultinormal",
diff --git a/src/bmi/samplers/_tfp/__init__.py b/src/bmi/samplers/_tfp/__init__.py
index 285dd3a6..679cb43a 100644
--- a/src/bmi/samplers/_tfp/__init__.py
+++ b/src/bmi/samplers/_tfp/__init__.py
@@ -10,6 +10,7 @@
 # isort: on
 from bmi.samplers._tfp._normal import MultivariateNormalDistribution
 from bmi.samplers._tfp._student import MultivariateStudentDistribution
+from bmi.samplers._tfp._wrapper import FineSampler
 
 __all__ = [
     "JointDistribution",
@@ -19,4 +20,5 @@
     "monte_carlo_mi_estimate",
     "MultivariateNormalDistribution",
     "MultivariateStudentDistribution",
+    "FineSampler",
 ]
diff --git a/src/bmi/samplers/_tfp/_core.py b/src/bmi/samplers/_tfp/_core.py
index 7c6e2de5..9ff6f256 100644
--- a/src/bmi/samplers/_tfp/_core.py
+++ b/src/bmi/samplers/_tfp/_core.py
@@ -31,17 +31,19 @@ class JointDistribution:
     dim_y: int
     analytic_mi: Optional[float] = None
 
-    def sample(self, key: jax.random.PRNGKeyArray, n: int) -> tuple[jnp.ndarray, jnp.ndarray]:
+    def sample(
+        self, n_points: int, key: jax.random.PRNGKeyArray
+    ) -> tuple[jnp.ndarray, jnp.ndarray]:
         """Sample from the joint distribution.
 
         Args:
+            n_points: number of samples to draw
             key: JAX random key
-            n: number of samples to draw
         """
-        if n < 1:
-            raise ValueError("n must be positive")
+        if n_points < 1:
+            raise ValueError("n_points must be positive")
 
-        xy = self.dist_joint.sample(seed=key, sample_shape=(n,))
+        xy = self.dist_joint.sample(seed=key, sample_shape=(n_points,))
         return xy[..., : self.dim_x], xy[..., self.dim_x :]  # noqa: E203 (formatting discrepancy)
 
     def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
@@ -160,7 +162,7 @@ def pmi_profile(key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int) -
     Returns:
         PMI profile, shape `(n,)`
     """
-    x, y = dist.sample(key, n)
+    x, y = dist.sample(key=key, n_points=n)
     return dist.pmi(x, y)
diff --git a/src/bmi/samplers/_tfp/_wrapper.py b/src/bmi/samplers/_tfp/_wrapper.py
new file mode 100644
index 00000000..3cc386ee
--- /dev/null
+++ b/src/bmi/samplers/_tfp/_wrapper.py
@@ -0,0 +1,48 @@
+"""A wrapper from TFP distributions to BMI samplers."""
+from typing import Optional, Union
+
+import jax
+
+from bmi.samplers._tfp._core import JointDistribution, monte_carlo_mi_estimate
+from bmi.samplers.base import BaseSampler, KeyArray, cast_to_rng
+
+
+class FineSampler(BaseSampler):
+    """Wrapper around a fine distribution."""
+
+    def __init__(
+        self,
+        dist: JointDistribution,
+        mi: Optional[float] = None,
+        mi_estimate_seed: Union[KeyArray, int] = 0,
+        mi_estimate_sample: int = 200_000,
+    ) -> None:
+        """Initializes the sampler.
+
+        Args:
+            dist: fine distribution to be wrapped
+            mi: mutual information of the fine distribution, if already calculated.
+                If not provided, it will be estimated via Monte Carlo sampling.
+            mi_estimate_seed: seed for the Monte Carlo sampling
+            mi_estimate_sample: number of samples for the Monte Carlo sampling
+        """
+        super().__init__(dim_x=dist.dim_x, dim_y=dist.dim_y)
+        self._dist = dist
+
+        if mi is None:
+            rng = cast_to_rng(mi_estimate_seed)
+            self._mi, self._mi_stderr = monte_carlo_mi_estimate(
+                key=rng, dist=self._dist, n=mi_estimate_sample
+            )
+        else:
+            self._mi = mi
+            self._mi_stderr = None
+
+    def sample(
+        self, n_points: int, rng: Union[int, KeyArray]
+    ) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]:
+        key = cast_to_rng(rng)
+        return self._dist.sample(n_points=n_points, key=key)
+
+    def mutual_information(self) -> float:
+        return self._mi
diff --git a/tests/samplers/tfp/test_core.py b/tests/samplers/tfp/test_core.py
index b2ab4895..1400bb3c 100644
--- a/tests/samplers/tfp/test_core.py
+++ b/tests/samplers/tfp/test_core.py
@@ -31,7 +31,7 @@ def distributions(dim_x: int = 2, dim_y: int = 3) -> list[bmi_tfp.JointDistribut
 @pytest.mark.parametrize("dist", distributions())
 def test_sample_and_pmi(dist: bmi_tfp.JointDistribution, n_samples: int = 10) -> None:
     """Checks whether we can sample from the distribution and calculate PMI."""
-    x, y = dist.sample(jax.random.PRNGKey(0), n=n_samples)
+    x, y = dist.sample(n_samples, jax.random.PRNGKey(0))
 
     assert x.shape == (n_samples, dist.dim_x)
     assert y.shape == (n_samples, dist.dim_y)
@@ -52,8 +52,8 @@ def test_transformed(dist: bmi_tfp.JointDistribution, n_points: int = 1_000) ->
 
     key = jax.random.PRNGKey(0)
 
-    x_base, y_base = base_dist.sample(key, n=n_points)
-    x_tran, y_tran = transformed.sample(key, n=n_points)
+    x_base, y_base = base_dist.sample(n_points, key)
+    x_tran, y_tran = transformed.sample(n_points, key)
 
     # Check shapes
     assert x_base.shape == x_tran.shape
diff --git a/tests/samplers/tfp/test_normal.py b/tests/samplers/tfp/test_normal.py
index c9269c86..747f4ece 100644
--- a/tests/samplers/tfp/test_normal.py
+++ b/tests/samplers/tfp/test_normal.py
@@ -12,7 +12,7 @@ def test_1v1(correlation: float = 0.5, n: int = 10):
 
     key = jax.random.PRNGKey(0)
 
-    x, y = dist.sample(key, n=n)
+    x, y = dist.sample(n, key)
 
     assert x.shape == (n, 1)
     assert y.shape == (n, 1)
@@ -23,7 +23,7 @@ def test_1v1(correlation: float = 0.5, n: int = 10):
         == BivariateNormalSampler(correlation=correlation).mutual_information()
     )
     # Check whether the Monte Carlo estimate is correct
-    estimate, _ = monte_carlo_mi_estimate(key, dist, n=5_000)
+    estimate, _ = monte_carlo_mi_estimate(key, dist=dist, n=5_000)
 
     assert pytest.approx(estimate, abs=0.01) == dist.analytic_mi
diff --git a/tests/samplers/tfp/test_wrapper.py b/tests/samplers/tfp/test_wrapper.py
new file mode 100644
index 00000000..e69de29b
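Usage sketch (not part of the patch): a minimal example of how the new FineSampler wrapper could be exercised once this diff is applied. The `samplers.fine` alias, `FineSampler(dist, mi_estimate_seed=..., mi_estimate_sample=...)`, `sample(n_points, rng)` and `mutual_information()` come directly from the changes above; the keyword arguments passed to `MultivariateNormalDistribution` (dim_x, dim_y, mean, covariance) and the concrete numbers are illustrative assumptions, not confirmed API.

import jax.numpy as jnp

import bmi.samplers as samplers

# A hypothetical 1 + 1 dimensional Gaussian with correlation 0.5
# (constructor arguments are assumed here for illustration).
dist = samplers.fine.MultivariateNormalDistribution(
    dim_x=1,
    dim_y=1,
    mean=jnp.zeros(2),
    covariance=jnp.array([[1.0, 0.5], [0.5, 1.0]]),
)

# With mi=None the mutual information is estimated via Monte Carlo,
# using mi_estimate_sample draws seeded by mi_estimate_seed.
sampler = samplers.fine.FineSampler(dist, mi_estimate_seed=0, mi_estimate_sample=10_000)

x, y = sampler.sample(n_points=1_000, rng=42)  # arrays of shape (1000, 1) each
print(sampler.mutual_information())            # Monte Carlo estimate of the MI

Note that the new tests/samplers/tfp/test_wrapper.py file is added empty in this diff; the snippet above only illustrates the intended call pattern.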