From de43a3aeaa1dd21c8e359c30a5aac303b7ffd74b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Fri, 22 Sep 2023 23:11:12 +0200 Subject: [PATCH] Adjust the workflow --- .../Mixtures/how_good_integration_is.smk | 110 ++++++++++++++---- 1 file changed, 86 insertions(+), 24 deletions(-) diff --git a/workflows/Mixtures/how_good_integration_is.smk b/workflows/Mixtures/how_good_integration_is.smk index 47b434ee..f8df390a 100644 --- a/workflows/Mixtures/how_good_integration_is.smk +++ b/workflows/Mixtures/how_good_integration_is.smk @@ -11,7 +11,8 @@ import pandas as pd import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt -import seaborn as sns + +from subplots_from_axsize import subplots_from_axsize import jax import jax.numpy as jnp @@ -49,6 +50,12 @@ ESTIMATORS: dict[str, Callable] = { "MC": monte_carlo } +ESTIMATOR_COLORS = { + "InfoNCE": "magenta", + "DV": "red", + "NWJ": "limegreen", + "MC": "mediumblue", +} four_balls = fine.mixture( proportions=jnp.array([0.3, 0.3, 0.2, 0.2]), @@ -106,7 +113,8 @@ rule all: input: ground_truth = expand("{setup}/ground_truth.json", setup=DISTRIBUTION_AND_PMIS), estimates=expand("{setup}/estimates.csv", setup=DISTRIBUTION_AND_PMIS), - performance_plots=expand("{setup}/performance.pdf", setup=DISTRIBUTION_AND_PMIS) + performance_plots=expand("{setup}/performance.pdf", setup=DISTRIBUTION_AND_PMIS), + plot_for_publication = "how_good_integration_is.pdf" rule sample_distribution: @@ -197,7 +205,80 @@ rule estimate_ground_truth_single_seed: indent=4, ) -rule plot_performance: +def plot_estimates(ax: plt.Axes, estimates_path, ground_truth_path) -> None: + df = pd.read_csv(estimates_path) + with open(ground_truth_path) as fh: + ground_truth = json.load(fh) + + # Add ground-truth information + x_axis =[df["n_points"].min(), df["n_points"].max()] + ax.plot(x_axis, [ground_truth["mi_mean"]] * 2, c="k", linestyle=":") + # ax.fill_between( + # x_axis, + # [ground_truth["mi_mean"] - ground_truth["mi_std"]] * 2, + # [ground_truth["mi_mean"] + ground_truth["mi_std"]] * 2, + # alpha=0.3, + # color="k", + # ) + + grouped = df.groupby(['n_points', 'estimator']).estimate.agg(['mean', 'std']).reset_index() + for estimator in grouped["estimator"].unique(): + sub_df = grouped[grouped["estimator"] == estimator] + sub_df = sub_df.sort_values("n_points") + + points = sub_df["n_points"].values + mean = sub_df["mean"].values + std = sub_df["std"].values + + color = ESTIMATOR_COLORS[estimator] + + ax.plot(points, mean, color=color, label=estimator) + ax.fill_between(points, mean - std, mean + std, alpha=0.1, color=color) + + +rule plot_performance_all: + input: + simple_ground_truth="Four_Balls/ground_truth.json", + simple_estimates = "Four_Balls/estimates.csv", + biased_ground_truth="Four_Balls_Biased/ground_truth.json", + biased_estimates = "Four_Balls_Biased/estimates.csv", + func_ground_truth="Four_Balls_SinSquare/ground_truth.json", + func_estimates = "Four_Balls_SinSquare/estimates.csv", + highdim_ground_truth="Normal-25Dim/ground_truth.json", + highdim_estimates = "Normal-25Dim/estimates.csv" + output: + "how_good_integration_is.pdf" + run: + fig, axs = subplots_from_axsize(1, 4, axsize=(2.5, 1.5), right=1.2, top=0.3) + + ax = axs[0] + ax.set_title("Mixture") + plot_estimates(ax, input.simple_estimates, input.simple_ground_truth) + + ax = axs[1] + ax.set_title("Constant bias") + plot_estimates(ax, input.biased_estimates, input.biased_ground_truth) + + ax = axs[2] + ax.set_title("Functional bias") + plot_estimates(ax, input.func_estimates, input.func_ground_truth) + + ax = axs[3] + ax.set_title("High-dimensional") + plot_estimates(ax, input.highdim_estimates, input.highdim_ground_truth) + + + for ax in axs: + ax.set_xlabel("Number of points") + ax.set_ylabel('Estimate') + ax.set_xscale("log", base=2) + # ax.set_xticks(df["n_points"].unique(), df["n_points"].unique()) + + axs[3].legend(title='Estimator', frameon=False, bbox_to_anchor=(1.01, 1), loc='upper left') + + fig.savefig(str(output)) + +rule plot_performance_single: input: ground_truth="{setup}/ground_truth.json", estimates="{setup}/estimates.csv" @@ -209,32 +290,13 @@ rule plot_performance: fig, ax = plt.subplots(figsize=(4, 3), dpi=150) - # Add ground-truth information - x_axis =[df["n_points"].min(), df["n_points"].max()] - ax.plot(x_axis, [ground_truth["mi_mean"]] * 2, c="k", linestyle=":") - ax.fill_between( - x_axis, - [ground_truth["mi_mean"] - ground_truth["mi_std"]] * 2, - [ground_truth["mi_mean"] + ground_truth["mi_std"]] * 2, - alpha=0.3, - color="k", - ) - - - # Plot means for estimators - grouped = df.groupby(['n_points', 'estimator']).estimate.agg(['mean', 'std']).reset_index() - sns.lineplot(x='n_points', y='mean', hue='estimator', data=grouped, palette='tab10', ax=ax) - - # Plot standard deviations - for estimator in grouped['estimator'].unique(): - subset = grouped[grouped['estimator'] == estimator] - ax.fill_between(subset['n_points'], subset['mean'] - subset['std'], subset['mean'] + subset['std'], alpha=0.1) + plot_estimates(ax, input.estimates, input.ground_truth) ax.set_xlabel("Number of points") ax.set_ylabel('Estimate') ax.set_xscale("log", base=2) - ax.set_xticks(df["n_points"].unique(), df["n_points"].unique()) + # ax.set_xticks(df["n_points"].unique(), df["n_points"].unique()) ax.legend(title='Estimator', frameon=False) fig.tight_layout()