From 369c86fe77d817821f85de0319d4cb3daecf7dec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Tue, 14 May 2024 13:40:37 +0200 Subject: [PATCH] Style changes --- .../projects/Mixtures/figure_bmm_vs_other.smk | 5 +++-- workflows/projects/Mixtures/fitting_gmm.smk | 20 +++++++++++-------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/workflows/projects/Mixtures/figure_bmm_vs_other.smk b/workflows/projects/Mixtures/figure_bmm_vs_other.smk index 71b1455e..f7422a45 100644 --- a/workflows/projects/Mixtures/figure_bmm_vs_other.smk +++ b/workflows/projects/Mixtures/figure_bmm_vs_other.smk @@ -119,10 +119,11 @@ rule generate_figure: ax.set_xlim(*task_config.xlim) ax.set_xticks(task_config.xticks) ax.set_yticks([]) + ax.set_ylim(-0.05, 1.01) ax.spines[["top", "left", "right"]].set_visible(False) mi_true = data_v2.groupby("task_id")["mi_true"].mean()[task_id] - ax.axvline(mi_true, linestyle="--", color="black", linewidth=2) + ax.axvline(mi_true, linestyle=":", color="black", linewidth=2) # Plot credible intervals from the BMM bmm_subtable = data_bmm[(data_bmm["task_id"] == task_id)].copy() @@ -140,7 +141,7 @@ rule generate_figure: index = (data_v2["task_id"] == task_id) & (data_v2["estimator_id"] == estimator_id) & (data_v2["n_samples"] == N_SAMPLES) estimates = data_v2[index]["mi_estimate"].values y = y_scaler.get_y(estimator_id=estimator_id, n_points=len(estimates)) - ax.scatter(estimates, y, color=estimator_config.color, s=DOT_SIZE) + ax.scatter(estimates, y, color=estimator_config.color, s=DOT_SIZE, alpha=0.4) ax = axs[0] ax.set_yticks(y_scaler.get_tick_locations(), ["BMM"] + [config.name for config in POINT_ESTIMATORS]) diff --git a/workflows/projects/Mixtures/fitting_gmm.smk b/workflows/projects/Mixtures/fitting_gmm.smk index 862e99ce..672b9fbc 100644 --- a/workflows/projects/Mixtures/fitting_gmm.smk +++ b/workflows/projects/Mixtures/fitting_gmm.smk @@ -117,7 +117,7 @@ DISTRIBUTIONS = { rule all: # For the main part of the manuscript input: - expand("plots/{dist_name}-{n_points}-10.pdf", dist_name=["AI", "Galaxy"], n_points=[250]) + expand("plots/{dist_name}-{n_points}-10.pdf", dist_name=["AI", "Galaxy"], n_points=[500]) rule plots_all: @@ -192,20 +192,22 @@ rule plot_pdf: approx_sample = "approx_samples/{dist_name}-{n_points}-{n_components}-0.npz", output: "plots/{dist_name}-{n_points}-{n_components}.pdf" run: - fig, axs = subplots_from_axsize(1, 4, axsize=(1.5, 1.5), top=0.3, wspace=0.3) + fig, axs = subplots_from_axsize(1, 4, axsize=(1.2, 1.2), top=0.3, wspace=[0.3, 0.05, 0.05], left=0.5, right=0.15) for ax in axs: ax.spines[['right', 'top']].set_visible(False) + FONTDICT = {'fontsize': 10} + # Visualise true sample ax = axs[0] - ax.set_title("Ground-truth sample") + ax.set_title("Ground-truth sample", fontdict=FONTDICT) true_sample = np.load(input.true_sample) visualise_points(true_sample["xs"], true_sample["ys"], ax) # Visualise approximate sample ax = axs[1] - ax.set_title("Simulated sample") + ax.set_title("Simulated sample", fontdict=FONTDICT) approx_sample = np.load(input.approx_sample) visualise_points(approx_sample["xs"], approx_sample["ys"], ax) @@ -214,7 +216,7 @@ rule plot_pdf: # Visualise posterior on mutual information ax = axs[2] - ax.set_title("Posterior MI") + ax.set_title("Posterior MI", fontdict=FONTDICT) mi_true = np.mean(pmi_true) mi_approx = np.mean(pmi_approx, axis=1) # (num_mcmc_samples,) ax.set_xlabel("MI") @@ -225,11 +227,13 @@ rule plot_pdf: # Visualise posterior on profile ax = axs[3] - ax.set_title("Posterior PMI profile") + ax.set_title("Posterior PMI profile", fontdict=FONTDICT) ax.set_xlabel("PMI") - min_val = np.min([pmi_true.min(), pmi_approx.min()]) - max_val = np.max([pmi_true.max(), pmi_approx.max()]) + quantile_min = 0.02 + quantile_max = 1 - quantile_min + min_val = np.min([np.quantile(pmi_true, quantile_min), np.quantile(pmi_approx, quantile_min)]) + max_val = np.max([np.quantile(pmi_true, quantile_max), np.quantile(pmi_approx, quantile_max)]) bins = np.linspace(min_val, max_val, 50) for pmi_vals in pmi_approx: