Skip to content

Commit

Permalink
Adjust the workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Sep 22, 2023
1 parent de2264a commit de43a3a
Showing 1 changed file with 86 additions and 24 deletions.
110 changes: 86 additions & 24 deletions workflows/Mixtures/how_good_integration_is.smk
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns

from subplots_from_axsize import subplots_from_axsize

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -49,6 +50,12 @@ ESTIMATORS: dict[str, Callable] = {
"MC": monte_carlo
}

ESTIMATOR_COLORS = {
"InfoNCE": "magenta",
"DV": "red",
"NWJ": "limegreen",
"MC": "mediumblue",
}

four_balls = fine.mixture(
proportions=jnp.array([0.3, 0.3, 0.2, 0.2]),
Expand Down Expand Up @@ -106,7 +113,8 @@ rule all:
input:
ground_truth = expand("{setup}/ground_truth.json", setup=DISTRIBUTION_AND_PMIS),
estimates=expand("{setup}/estimates.csv", setup=DISTRIBUTION_AND_PMIS),
performance_plots=expand("{setup}/performance.pdf", setup=DISTRIBUTION_AND_PMIS)
performance_plots=expand("{setup}/performance.pdf", setup=DISTRIBUTION_AND_PMIS),
plot_for_publication = "how_good_integration_is.pdf"


rule sample_distribution:
Expand Down Expand Up @@ -197,7 +205,80 @@ rule estimate_ground_truth_single_seed:
indent=4,
)

rule plot_performance:
def plot_estimates(ax: plt.Axes, estimates_path, ground_truth_path) -> None:
df = pd.read_csv(estimates_path)
with open(ground_truth_path) as fh:
ground_truth = json.load(fh)

# Add ground-truth information
x_axis =[df["n_points"].min(), df["n_points"].max()]
ax.plot(x_axis, [ground_truth["mi_mean"]] * 2, c="k", linestyle=":")
# ax.fill_between(
# x_axis,
# [ground_truth["mi_mean"] - ground_truth["mi_std"]] * 2,
# [ground_truth["mi_mean"] + ground_truth["mi_std"]] * 2,
# alpha=0.3,
# color="k",
# )

grouped = df.groupby(['n_points', 'estimator']).estimate.agg(['mean', 'std']).reset_index()
for estimator in grouped["estimator"].unique():
sub_df = grouped[grouped["estimator"] == estimator]
sub_df = sub_df.sort_values("n_points")

points = sub_df["n_points"].values
mean = sub_df["mean"].values
std = sub_df["std"].values

color = ESTIMATOR_COLORS[estimator]

ax.plot(points, mean, color=color, label=estimator)
ax.fill_between(points, mean - std, mean + std, alpha=0.1, color=color)


rule plot_performance_all:
input:
simple_ground_truth="Four_Balls/ground_truth.json",
simple_estimates = "Four_Balls/estimates.csv",
biased_ground_truth="Four_Balls_Biased/ground_truth.json",
biased_estimates = "Four_Balls_Biased/estimates.csv",
func_ground_truth="Four_Balls_SinSquare/ground_truth.json",
func_estimates = "Four_Balls_SinSquare/estimates.csv",
highdim_ground_truth="Normal-25Dim/ground_truth.json",
highdim_estimates = "Normal-25Dim/estimates.csv"
output:
"how_good_integration_is.pdf"
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(2.5, 1.5), right=1.2, top=0.3)

ax = axs[0]
ax.set_title("Mixture")
plot_estimates(ax, input.simple_estimates, input.simple_ground_truth)

ax = axs[1]
ax.set_title("Constant bias")
plot_estimates(ax, input.biased_estimates, input.biased_ground_truth)

ax = axs[2]
ax.set_title("Functional bias")
plot_estimates(ax, input.func_estimates, input.func_ground_truth)

ax = axs[3]
ax.set_title("High-dimensional")
plot_estimates(ax, input.highdim_estimates, input.highdim_ground_truth)


for ax in axs:
ax.set_xlabel("Number of points")
ax.set_ylabel('Estimate')
ax.set_xscale("log", base=2)
# ax.set_xticks(df["n_points"].unique(), df["n_points"].unique())

axs[3].legend(title='Estimator', frameon=False, bbox_to_anchor=(1.01, 1), loc='upper left')

fig.savefig(str(output))

rule plot_performance_single:
input:
ground_truth="{setup}/ground_truth.json",
estimates="{setup}/estimates.csv"
Expand All @@ -209,32 +290,13 @@ rule plot_performance:

fig, ax = plt.subplots(figsize=(4, 3), dpi=150)

# Add ground-truth information
x_axis =[df["n_points"].min(), df["n_points"].max()]
ax.plot(x_axis, [ground_truth["mi_mean"]] * 2, c="k", linestyle=":")
ax.fill_between(
x_axis,
[ground_truth["mi_mean"] - ground_truth["mi_std"]] * 2,
[ground_truth["mi_mean"] + ground_truth["mi_std"]] * 2,
alpha=0.3,
color="k",
)


# Plot means for estimators
grouped = df.groupby(['n_points', 'estimator']).estimate.agg(['mean', 'std']).reset_index()
sns.lineplot(x='n_points', y='mean', hue='estimator', data=grouped, palette='tab10', ax=ax)

# Plot standard deviations
for estimator in grouped['estimator'].unique():
subset = grouped[grouped['estimator'] == estimator]
ax.fill_between(subset['n_points'], subset['mean'] - subset['std'], subset['mean'] + subset['std'], alpha=0.1)
plot_estimates(ax, input.estimates, input.ground_truth)

ax.set_xlabel("Number of points")
ax.set_ylabel('Estimate')

ax.set_xscale("log", base=2)
ax.set_xticks(df["n_points"].unique(), df["n_points"].unique())
# ax.set_xticks(df["n_points"].unique(), df["n_points"].unique())
ax.legend(title='Estimator', frameon=False)

fig.tight_layout()
Expand Down

0 comments on commit de43a3a

Please sign in to comment.