Skip to content

Commit

Permalink
Style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed May 14, 2024
1 parent 878b663 commit 369c86f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
5 changes: 3 additions & 2 deletions workflows/projects/Mixtures/figure_bmm_vs_other.smk
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,11 @@ rule generate_figure:
ax.set_xlim(*task_config.xlim)
ax.set_xticks(task_config.xticks)
ax.set_yticks([])
ax.set_ylim(-0.05, 1.01)
ax.spines[["top", "left", "right"]].set_visible(False)

mi_true = data_v2.groupby("task_id")["mi_true"].mean()[task_id]
ax.axvline(mi_true, linestyle="--", color="black", linewidth=2)
ax.axvline(mi_true, linestyle=":", color="black", linewidth=2)

# Plot credible intervals from the BMM
bmm_subtable = data_bmm[(data_bmm["task_id"] == task_id)].copy()
Expand All @@ -140,7 +141,7 @@ rule generate_figure:
index = (data_v2["task_id"] == task_id) & (data_v2["estimator_id"] == estimator_id) & (data_v2["n_samples"] == N_SAMPLES)
estimates = data_v2[index]["mi_estimate"].values
y = y_scaler.get_y(estimator_id=estimator_id, n_points=len(estimates))
ax.scatter(estimates, y, color=estimator_config.color, s=DOT_SIZE)
ax.scatter(estimates, y, color=estimator_config.color, s=DOT_SIZE, alpha=0.4)

ax = axs[0]
ax.set_yticks(y_scaler.get_tick_locations(), ["BMM"] + [config.name for config in POINT_ESTIMATORS])
Expand Down
20 changes: 12 additions & 8 deletions workflows/projects/Mixtures/fitting_gmm.smk
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ DISTRIBUTIONS = {
rule all:
# For the main part of the manuscript
input:
expand("plots/{dist_name}-{n_points}-10.pdf", dist_name=["AI", "Galaxy"], n_points=[250])
expand("plots/{dist_name}-{n_points}-10.pdf", dist_name=["AI", "Galaxy"], n_points=[500])


rule plots_all:
Expand Down Expand Up @@ -192,20 +192,22 @@ rule plot_pdf:
approx_sample = "approx_samples/{dist_name}-{n_points}-{n_components}-0.npz",
output: "plots/{dist_name}-{n_points}-{n_components}.pdf"
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(1.5, 1.5), top=0.3, wspace=0.3)
fig, axs = subplots_from_axsize(1, 4, axsize=(1.2, 1.2), top=0.3, wspace=[0.3, 0.05, 0.05], left=0.5, right=0.15)

for ax in axs:
ax.spines[['right', 'top']].set_visible(False)

FONTDICT = {'fontsize': 10}

# Visualise true sample
ax = axs[0]
ax.set_title("Ground-truth sample")
ax.set_title("Ground-truth sample", fontdict=FONTDICT)
true_sample = np.load(input.true_sample)
visualise_points(true_sample["xs"], true_sample["ys"], ax)

# Visualise approximate sample
ax = axs[1]
ax.set_title("Simulated sample")
ax.set_title("Simulated sample", fontdict=FONTDICT)
approx_sample = np.load(input.approx_sample)
visualise_points(approx_sample["xs"], approx_sample["ys"], ax)

Expand All @@ -214,7 +216,7 @@ rule plot_pdf:

# Visualise posterior on mutual information
ax = axs[2]
ax.set_title("Posterior MI")
ax.set_title("Posterior MI", fontdict=FONTDICT)
mi_true = np.mean(pmi_true)
mi_approx = np.mean(pmi_approx, axis=1) # (num_mcmc_samples,)
ax.set_xlabel("MI")
Expand All @@ -225,11 +227,13 @@ rule plot_pdf:

# Visualise posterior on profile
ax = axs[3]
ax.set_title("Posterior PMI profile")
ax.set_title("Posterior PMI profile", fontdict=FONTDICT)
ax.set_xlabel("PMI")

min_val = np.min([pmi_true.min(), pmi_approx.min()])
max_val = np.max([pmi_true.max(), pmi_approx.max()])
quantile_min = 0.02
quantile_max = 1 - quantile_min
min_val = np.min([np.quantile(pmi_true, quantile_min), np.quantile(pmi_approx, quantile_min)])
max_val = np.max([np.quantile(pmi_true, quantile_max), np.quantile(pmi_approx, quantile_max)])

bins = np.linspace(min_val, max_val, 50)
for pmi_vals in pmi_approx:
Expand Down

0 comments on commit 369c86f

Please sign in to comment.