diff --git a/scripts/is_ksg_estimating_pmi.py b/scripts/is_ksg_estimating_pmi.py new file mode 100644 index 00000000..0f8a1cb0 --- /dev/null +++ b/scripts/is_ksg_estimating_pmi.py @@ -0,0 +1,68 @@ +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np + +import bmi +from bmi.samplers import fine + +n_dim = 5 +n_points: int = 5_000 +ks = (5, 10, 20, 50) + +dist = fine.MultivariateNormalDistribution( + dim_x=n_dim, + dim_y=n_dim, + mean=jnp.zeros(2 * n_dim), + covariance=bmi.samplers.canonical_correlation([0.8] * n_dim), +) + +mi, mi_stderr = fine.monte_carlo_mi_estimate(jax.random.PRNGKey(10), dist, n=10_000) + + +xs, ys = dist.sample(n_points, jax.random.PRNGKey(42)) +pmis = dist.pmi(xs, ys) + +min_pmi = jnp.min(pmis) - 0.1 +max_pmi = jnp.max(pmis) + 0.1 + + +fig, axs = plt.subplots(len(ks), 3, figsize=(6, 2 * len(ks)), dpi=250) + +for i, k in enumerate(ks): + estimator = bmi.estimators.KSGEnsembleFirstEstimatorSlow(neighborhoods=(k,), standardize=False) + + pseudo_pmis = estimator._calculate_digammas(xs, ys, ks=(k,))[k] + + bins = jnp.linspace(min_pmi, max_pmi, 21) + + ax = axs[i, 0] + ax.hist(pmis, bins=bins, density=True) + ax.set_title("True PMI") + + ax.set_xlabel(f"$I(X; Y) = {mi:.2f}$") + ax.set_ylabel(f"$k={k}$") + + ax = axs[i, 1] + ax.hist(pseudo_pmis, bins=bins, density=True) + ax.set_title("KSG PMI") + ax.set_xlabel(f"$I(X; Y) = {np.mean(pseudo_pmis):.2f}$") + + ax = axs[i, 2] + ts = jnp.linspace(min_pmi, max_pmi, 3) + ax.plot(ts, ts, color="maroon", linestyle="--") + + ax.scatter(pmis, pseudo_pmis, s=2, alpha=0.1, c="k") + ax.set_xlabel("True PMI") + ax.set_ylabel("KSG PMI") + + ax.set_xlim(min_pmi, max_pmi) + ax.set_ylim(min_pmi, max_pmi) + ax.set_aspect("equal") + + corr = np.corrcoef(pmis, pseudo_pmis)[0, 1] + ax.set_title(f"$r={corr:.2f}$") + +fig.tight_layout() + +fig.savefig("figure.pdf") diff --git a/src/bmi/estimators/ksg.py b/src/bmi/estimators/ksg.py index 8fe142cc..661dd2a8 100644 --- a/src/bmi/estimators/ksg.py +++ b/src/bmi/estimators/ksg.py @@ -179,17 +179,19 @@ def __init__( self._fitted = False self._mi_dict = dict() # set by fit() - def fit(self, x: ArrayLike, y: ArrayLike) -> None: + def _calculate_digammas( + self, x: ArrayLike, y: ArrayLike, ks: Sequence[int] + ) -> dict[int, np.ndarray]: space = ProductSpace(x=x, y=y, standardize=self._params.standardize) n_points = len(space) - if n_points <= max(self._params.neighborhoods): + if n_points <= max(ks): raise ValueError( - f"Maximum neighborhood used is {max(self._params.neighborhoods)} " + f"Maximum neighborhood used is {max(ks)} " f"but the number of points provided is only {n_points}." ) - digammas_dict = {k: [] for k in self._params.neighborhoods} + digammas_dict = {k: [] for k in ks} for index in range(n_points): # Distances from x[index] to all the points: @@ -205,7 +207,7 @@ def fit(self, x: ArrayLike, y: ArrayLike) -> None: # And we sort the point indices by being the closest to the considered one closest_points = sorted(range(len(distances_z)), key=lambda i: distances_z[i]) - for k in self._params.neighborhoods: + for k in ks: # Note that the points are 0-indexed and that the "0th neighbor" # is the point itself (as distance(x, x) = 0 is the smallest possible) # Hence, the kth neighbour is at index k @@ -219,9 +221,16 @@ def fit(self, x: ArrayLike, y: ArrayLike) -> None: digammas_per_point = _DIGAMMA(n_x + 1) + _DIGAMMA(n_y + 1) digammas_dict[k].append(digammas_per_point) + return { + k: _DIGAMMA(k) - np.array(raw_values) + _DIGAMMA(n_points) + for k, raw_values in digammas_dict.items() + } + + def fit(self, x: ArrayLike, y: ArrayLike) -> None: + digammas_dict = self._calculate_digammas(x, y, ks=self._params.neighborhoods) + for k, digammas in digammas_dict.items(): - mi_estimate = _DIGAMMA(k) - np.mean(digammas) + _DIGAMMA(n_points) - self._mi_dict[k] = max(0.0, mi_estimate) + self._mi_dict[k] = max(0.0, np.mean(digammas)) self._fitted = True diff --git a/workflows/Mixtures/ksg_pmi_profile.smk b/workflows/Mixtures/ksg_pmi_profile.smk new file mode 100644 index 00000000..86683f62 --- /dev/null +++ b/workflows/Mixtures/ksg_pmi_profile.smk @@ -0,0 +1,163 @@ +import numpy as np +import pandas as pd +import matplotlib +from subplots_from_axsize import subplots_from_axsize +matplotlib.use("agg") + +import bmi +from bmi.samplers import fine + +import jax +import jax.numpy as jnp + + +N_SAMPLES = [100] +SEEDS = list(range(20)) + +ESTIMATORS = { + 'KSG-10': bmi.estimators.KSGEnsembleFirstEstimator(neighborhoods=(10,)), + "CCA": bmi.estimators.CCAMutualInformationEstimator(), +} +ESTIMATOR_NAMES = { + "KSG-10": "KSG", + "CCA": "CCA", +} +ESTIMATOR_COLORS = { + "KSG-10": "#d62728", + "CCA": "#1f77b4", +} +assert set(ESTIMATOR_NAMES.keys()) == set(ESTIMATORS.keys()) + + +def get_sampler(n: int) -> bmi.samplers.SplitMultinormal: + return bmi.samplers.SplitMultinormal( + dim_x=n, + dim_y=n, + covariance=bmi.samplers.canonical_correlation(rho=[0.5] * n) + ) + +UNSCALED_TASKS = { + "normal-1": bmi.benchmark.Task( + sampler=get_sampler(1), + task_id="normal-1", + task_name="Normal 1 x 1", + ), + "normal-3": bmi.benchmark.Task( + sampler=get_sampler(3), + task_id="normal-3", + task_name="Normal 3 x 3", + ), + "normal-5": bmi.benchmark.Task( + sampler=get_sampler(5), + task_id="normal-5", + task_name="Normal 5 x 5", + ), +} + +HEIGHT: float = 1.3 + +# === WORKDIR === +workdir: "generated/mixtures/aistats-rebuttal/" + +rule all: + input: + 'results.csv', + 'rebuttal_figure.pdf', + 'pmi_hist.pdf' + + + +rule plot_results: + output: 'rebuttal_figure.pdf' + input: 'results.csv' + run: + data = pd.read_csv(str(input)) + fig, ax = subplots_from_axsize(1, 1, (2, HEIGHT), right=1.3) + + data_5k = data[data['n_samples'] == 100] + tasks = ["normal-1", "normal-3", "normal-5"] + tasks_official = ["$n=1$", "$n=3$", "$n=5$"] + + for estimator_id, data_est in data_5k.groupby('estimator_id'): + ax.scatter( + data_est['task_id'].apply(lambda e: tasks.index(e)) + 0.05 * np.random.normal(size=len(data_est)), + data_est['mi_estimate'], + label=ESTIMATOR_NAMES[estimator_id], + color=ESTIMATOR_COLORS[estimator_id], + alpha=0.2, s=3**2, + rasterized=True, + ) + + _flag = True + for task_id, data_task in data_5k.groupby('task_id'): + true_mi = data_task['mi_true'].mean() + x = tasks.index(task_id) + ax.plot([x - 0.2, x + 0.2], [true_mi, true_mi], ':k', label="True MI" if _flag else None) + _flag = False + + ax.set_xticks(range(len(tasks)), tasks_official) + + ax.legend(frameon=False, loc='upper left', bbox_to_anchor=(1, 1)) + ax.spines[['top', 'right']].set_visible(False) + ax.set_ylim(-0.1, 1.3) + ax.set_ylabel('MI') + fig.savefig(str(output)) + + +rule plot_pmi_hist: + output: 'pmi_hist.pdf' + run: + n_dim: int = 1 + n_points: int = 5_000 + k: int = 20 + + dist = fine.MultivariateNormalDistribution( + dim_x=n_dim, + dim_y=n_dim, + mean=jnp.zeros(2 * n_dim), + covariance=bmi.samplers.canonical_correlation([0.8] * n_dim), + ) + + xs, ys = dist.sample(n_points, jax.random.PRNGKey(42)) + pmis = dist.pmi(xs, ys) + + min_pmi = jnp.min(pmis) - 0.1 + max_pmi = jnp.max(pmis) + 0.1 + + fig, axs = subplots_from_axsize(1, 3, (2, 1.3), wspace=[0.7, 0.3]) + + bins = jnp.linspace(-1.5, 3.5, 31) + + ax = axs[0] + ax.hist(pmis, bins=bins, density=True, alpha=0.5, color="black") + ax.set_xlabel("True PMI") + ax.set_ylabel("Frequency") + + estimator = bmi.estimators.KSGEnsembleFirstEstimatorSlow(neighborhoods=(k,), standardize=False) + pseudo_pmis = estimator._calculate_digammas(xs, ys, ks=(k,))[k] + ax = axs[1] + ax.hist(pseudo_pmis, bins=bins, density=True, alpha=0.5, color="red") + ax.set_xlabel("KSG PMI") + ax.set_ylabel("Frequency") + + ax = axs[2] + ts = jnp.linspace(min_pmi, max_pmi, 3) + ax.plot(ts, jnp.zeros_like(ts), color="darkblue", linestyle="--") + + ax.scatter(pmis, pseudo_pmis - pmis, s=2, alpha=0.1, c="k", rasterized=True) + ax.set_xlabel("True PMI") + ax.set_ylabel("KSG PMI $-$ True PMI") + + ax.set_xlim(min_pmi, max_pmi) + ax.set_ylim(min_pmi, max_pmi) + ax.set_aspect("equal") + + corr = np.corrcoef(pmis, pseudo_pmis)[0, 1] + # ax.annotate(f"$r={corr:.2f}$", xy=(0.05, 0.95), xycoords="axes fraction", ha="left", va="top") + + for ax in axs: + ax.spines[['top', 'right']].set_visible(False) + + fig.savefig(str(output)) + +include: "_benchmark_rules.smk"