From d990340c641ed02168fab9ea2ebf87c3d010b943 Mon Sep 17 00:00:00 2001 From: Simon Kucharsky Date: Tue, 28 Jan 2025 11:50:00 +0100 Subject: [PATCH] fix mc_calibration plot (#294) --- bayesflow/diagnostics/plots/mc_calibration.py | 27 +++--- bayesflow/utils/__init__.py | 2 +- bayesflow/utils/plot_utils.py | 4 +- bayesflow/utils/validators.py | 96 +++++++++++++------ 4 files changed, 82 insertions(+), 47 deletions(-) diff --git a/bayesflow/diagnostics/plots/mc_calibration.py b/bayesflow/diagnostics/plots/mc_calibration.py index 8a1972e5f..6c78170a3 100644 --- a/bayesflow/diagnostics/plots/mc_calibration.py +++ b/bayesflow/diagnostics/plots/mc_calibration.py @@ -68,41 +68,42 @@ def mc_calibration( # Gather plot data and metadata into a dictionary plot_data = prepare_plot_data( - estimates=pred_models, - ground_truths=true_models, + targets=pred_models, + references=true_models, variable_names=model_names, num_col=num_col, num_row=num_row, figsize=figsize, + default_name="M", ) # Compute calibration cal_errors, true_probs, pred_probs = expected_calibration_error( - plot_data["ground_truths"], plot_data["estimates"], num_bins + plot_data["references"], plot_data["targets"], num_bins ) for j, ax in enumerate(plot_data["axes"].flat): # Plot calibration curve - ax[j].plot(pred_probs[j], true_probs[j], "o-", color=color) + ax.plot(pred_probs[j], true_probs[j], "o-", color=color) # Plot PMP distribution over bins uniform_bins = np.linspace(0.0, 1.0, num_bins + 1) - norm_weights = np.ones_like(plot_data["estimates"]) / len(plot_data["estimates"]) - ax[j].hist(plot_data["estimates"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3) + norm_weights = np.ones_like(plot_data["targets"]) / len(plot_data["targets"]) + ax.hist(plot_data["targets"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3) # Plot AB line - ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9) + ax.plot((0, 1), (0, 1), "--", color="black", 
alpha=0.9) # Tweak plot - ax[j].set_xlim([0 - epsilon, 1 + epsilon]) - ax[j].set_ylim([0 - epsilon, 1 + epsilon]) - ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) - ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) + ax.set_xlim([0 - epsilon, 1 + epsilon]) + ax.set_ylim([0 - epsilon, 1 + epsilon]) + ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) + ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) # Add ECE label add_metric( - ax[j], - metric_text=r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}", + ax, + metric_text=r"$\widehat{{\mathrm{{ECE}}}}$", metric_value=cal_errors[j], metric_fontsize=metric_fontsize, ) diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index d86827899..a4b084388 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -35,7 +35,7 @@ ) from .optimal_transport import optimal_transport from .plot_utils import ( - check_posterior_prior_shapes, + check_estimates_prior_shapes, prepare_plot_data, add_titles_and_labels, prettify_subplots, diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index a1ddbeed6..11c79995a 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt import seaborn as sns -from .validators import check_posterior_prior_shapes +from .validators import check_estimates_prior_shapes from .dict_utils import dicts_to_arrays @@ -52,7 +52,7 @@ def prepare_plot_data( plot_data = dicts_to_arrays( targets=targets, references=references, variable_names=variable_names, default_name=default_name ) - check_posterior_prior_shapes(plot_data["targets"], plot_data["references"]) + check_estimates_prior_shapes(plot_data["targets"], plot_data["references"]) # Configure layout num_row, num_col = set_layout(plot_data["num_variables"], num_row, num_col, stacked) diff --git a/bayesflow/utils/validators.py b/bayesflow/utils/validators.py index f40998fb9..2eae6391b 100644 --- a/bayesflow/utils/validators.py +++ 
b/bayesflow/utils/validators.py
@@ -8,48 +8,82 @@ def check_lengths_same(*args):
         raise ValueError(f"All tuple arguments must have the same length, but lengths are {tuple(map(len, args))}.")
 
 
-def check_posterior_prior_shapes(post_variables: Tensor, prior_variables: Tensor):
+def check_prior_shapes(variables: Tensor):
     """
-    Checks requirements for the shapes of posterior and prior draws as
-    necessitated by most diagnostic functions.
+    Checks the shape of prior draws as required by most diagnostic functions
 
     Parameters
     ----------
-    post_samples : Tensor of shape (num_data_sets, num_post_draws, num_params)
-        The posterior draws obtained from num_data_sets
-    prior_samples : Tensor of shape (num_data_sets, num_params)
-        The prior draws obtained for generating num_data_sets
-
-    Raises
-    ------
-    ShapeError
-        If there is a deviation form the expected shapes of `post_samples` and `prior_samples`.
+    variables : Tensor of shape (num_data_sets, num_params)
+        The prior draws used for generating num_data_sets
     """
 
-    if len(post_variables.shape) != 3:
+    if len(variables.shape) != 2:
         raise ShapeError(
-            "post_samples should be a 3-dimensional array, with the "
-            "first dimension being the number of (simulated) data sets, "
-            "the second dimension being the number of posterior draws per data set, "
-            "and the third dimension being the number of parameters (marginal distributions), "
-            f"but your input has dimensions {len(post_variables.shape)}"
+            "prior_samples should be a 2-dimensional array, with the "
+            "first dimension being the number of (simulated) data sets / prior draws "
+            "and the second dimension being the number of variables, "
+            f"but your input has dimensions {len(variables.shape)}"
         )
-    elif len(prior_variables.shape) != 2:
+
+
+def check_estimates_shapes(variables: Tensor):
+    """
+    Checks the shape of model-generated predictions (posterior draws, point estimates)
+    as required by most diagnostic functions
+
+    Parameters
+    ----------
+    variables : Tensor of shape (num_data_sets, num_post_draws, num_params) or (num_data_sets, num_params)
+        The posterior draws or point estimates obtained from num_data_sets
+    """
+    if len(variables.shape) != 2 and len(variables.shape) != 3:
         raise ShapeError(
-            "prior_samples should be a 2-dimensional array, with the "
-            "first dimension being the number of (simulated) data sets / prior draws "
-            "and the second dimension being the number of parameters (marginal distributions), "
-            f"but your input has dimensions {len(prior_variables.shape)}"
+            "estimates should be a 2- or 3-dimensional array, with the "
+            "first dimension being the number of data sets, "
+            "(optional) second dimension the number of posterior draws per data set, "
+            "and the last dimension the number of estimated variables, "
+            f"but your input has dimensions {len(variables.shape)}"
         )
-    elif post_variables.shape[0] != prior_variables.shape[0]:
+
+
+def check_consistent_shapes(estimates: Tensor, prior_samples: Tensor):
+    """
+    Checks whether the model-generated predictions (posterior draws, point estimates) and
+    prior_samples have consistent leading (num_data_sets) and trailing (num_params) dimensions
+    """
+    if estimates.shape[0] != prior_samples.shape[0]:
         raise ShapeError(
-            "The number of elements over the first dimension of post_samples and prior_samples"
-            f"should match, but post_samples has {post_variables.shape[0]} and prior_samples has "
-            f"{prior_variables.shape[0]} elements, respectively."
+            "The number of elements over the first dimension of estimates and prior_samples "
+            f"should match, but estimates has {estimates.shape[0]} and prior_samples has "
+            f"{prior_samples.shape[0]} elements, respectively."
         )
-    elif post_variables.shape[-1] != prior_variables.shape[-1]:
+    if estimates.shape[-1] != prior_samples.shape[-1]:
         raise ShapeError(
-            "The number of elements over the last dimension of post_samples and prior_samples"
-            f"should match, but post_samples has {post_variables.shape[1]} and prior_samples has "
-            f"{prior_variables.shape[-1]} elements, respectively."
+            "The number of elements over the last dimension of estimates and prior_samples "
+            f"should match, but estimates has {estimates.shape[-1]} and prior_samples has "
+            f"{prior_samples.shape[-1]} elements, respectively."
         )
+
+
+def check_estimates_prior_shapes(estimates: Tensor, prior_samples: Tensor):
+    """
+    Checks requirements for the shapes of estimates and prior draws as
+    necessitated by most diagnostic functions.
+
+    Parameters
+    ----------
+    estimates : Tensor of shape (num_data_sets, num_post_draws, num_params) or (num_data_sets, num_params)
+        The model-generated predictions (posterior draws, point estimates) obtained from num_data_sets
+    prior_samples : Tensor of shape (num_data_sets, num_params)
+        The prior draws obtained for generating num_data_sets
+
+    Raises
+    ------
+    ShapeError
+        If there is a deviation from the expected shapes of `estimates` and `prior_samples`.
+    """
+
+    check_estimates_shapes(estimates)
+    check_prior_shapes(prior_samples)
+    check_consistent_shapes(estimates, prior_samples)