Skip to content

Commit

Permalink
fix mc_calibration plot (#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kucharssim authored Jan 28, 2025
1 parent 146f050 commit d990340
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 47 deletions.
27 changes: 14 additions & 13 deletions bayesflow/diagnostics/plots/mc_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,41 +68,42 @@ def mc_calibration(

# Gather plot data and metadata into a dictionary
plot_data = prepare_plot_data(
estimates=pred_models,
ground_truths=true_models,
targets=pred_models,
references=true_models,
variable_names=model_names,
num_col=num_col,
num_row=num_row,
figsize=figsize,
default_name="M",
)

# Compute calibration
cal_errors, true_probs, pred_probs = expected_calibration_error(
plot_data["ground_truths"], plot_data["estimates"], num_bins
plot_data["references"], plot_data["targets"], num_bins
)

for j, ax in enumerate(plot_data["axes"].flat):
# Plot calibration curve
ax[j].plot(pred_probs[j], true_probs[j], "o-", color=color)
ax.plot(pred_probs[j], true_probs[j], "o-", color=color)

# Plot PMP distribution over bins
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
norm_weights = np.ones_like(plot_data["estimates"]) / len(plot_data["estimates"])
ax[j].hist(plot_data["estimates"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)
norm_weights = np.ones_like(plot_data["targets"]) / len(plot_data["targets"])
ax.hist(plot_data["targets"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

# Plot AB line
ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9)
ax.plot((0, 1), (0, 1), "--", color="black", alpha=0.9)

# Tweak plot
ax[j].set_xlim([0 - epsilon, 1 + epsilon])
ax[j].set_ylim([0 - epsilon, 1 + epsilon])
ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_xlim([0 - epsilon, 1 + epsilon])
ax.set_ylim([0 - epsilon, 1 + epsilon])
ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

# Add ECE label
add_metric(
ax[j],
metric_text=r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}",
ax,
metric_text=r"$\widehat{{\mathrm{{ECE}}}}$",
metric_value=cal_errors[j],
metric_fontsize=metric_fontsize,
)
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from .optimal_transport import optimal_transport
from .plot_utils import (
check_posterior_prior_shapes,
check_estimates_prior_shapes,
prepare_plot_data,
add_titles_and_labels,
prettify_subplots,
Expand Down
4 changes: 2 additions & 2 deletions bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import matplotlib.pyplot as plt
import seaborn as sns

from .validators import check_posterior_prior_shapes
from .validators import check_estimates_prior_shapes
from .dict_utils import dicts_to_arrays


Expand Down Expand Up @@ -52,7 +52,7 @@ def prepare_plot_data(
plot_data = dicts_to_arrays(
targets=targets, references=references, variable_names=variable_names, default_name=default_name
)
check_posterior_prior_shapes(plot_data["targets"], plot_data["references"])
check_estimates_prior_shapes(plot_data["targets"], plot_data["references"])

# Configure layout
num_row, num_col = set_layout(plot_data["num_variables"], num_row, num_col, stacked)
Expand Down
96 changes: 65 additions & 31 deletions bayesflow/utils/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,82 @@ def check_lengths_same(*args):
raise ValueError(f"All tuple arguments must have the same length, but lengths are {tuple(map(len, args))}.")


def check_posterior_prior_shapes(post_variables: Tensor, prior_variables: Tensor):
def check_prior_shapes(variables: Tensor):
"""
Checks requirements for the shapes of posterior and prior draws as
necessitated by most diagnostic functions.
Checks the shape of posterior draws as required by most diagnostic functions
Parameters
----------
post_samples : Tensor of shape (num_data_sets, num_post_draws, num_params)
The posterior draws obtained from num_data_sets
prior_samples : Tensor of shape (num_data_sets, num_params)
The prior draws obtained for generating num_data_sets
Raises
------
ShapeError
If there is a deviation form the expected shapes of `post_samples` and `prior_samples`.
variables : Tensor of shape (num_data_sets, num_params)
The prior_samples from generating num_data_sets
"""

if len(post_variables.shape) != 3:
if len(variables.shape) != 2:
raise ShapeError(
"post_samples should be a 3-dimensional array, with the "
"first dimension being the number of (simulated) data sets, "
"the second dimension being the number of posterior draws per data set, "
"and the third dimension being the number of parameters (marginal distributions), "
f"but your input has dimensions {len(post_variables.shape)}"
"prior_samples samples should be a 2-dimensional array, with the "
"first dimension being the number of (simulated) data sets / prior_samples draws "
"and the second dimension being the number of variables, "
f"but your input has dimensions {len(variables.shape)}"
)
elif len(prior_variables.shape) != 2:


def check_estimates_shapes(variables: Tensor):
"""
Checks the shape of model-generated predictions (posterior draws, point estimates)
as required by most diagnostic functions
Parameters
----------
variables : Tensor of shape (num_data_sets, num_post_draws, num_params)
The prior_samples from generating num_data_sets
"""
if len(variables.shape) != 2 and len(variables.shape) != 3:
raise ShapeError(
"prior_samples should be a 2-dimensional array, with the "
"first dimension being the number of (simulated) data sets / prior draws "
"and the second dimension being the number of parameters (marginal distributions), "
f"but your input has dimensions {len(prior_variables.shape)}"
"estimates should be a 2- or 3-dimensional array, with the "
"first dimension being the number of data sets, "
"(optional) second dimension the number of posterior draws per data set, "
"and the last dimension the number of estimated variables, "
f"but your input has dimensions {len(variables.shape)}"
)
elif post_variables.shape[0] != prior_variables.shape[0]:


def check_consistent_shapes(estimates: Tensor, prior_samples: Tensor):
"""
Checks whether the model-generated predictions (posterior draws, point estimates) and
prior_samples have consistent leading (num_data_sets) and trailing (num_params) dimensions
"""
if estimates.shape[0] != prior_samples.shape[0]:
raise ShapeError(
"The number of elements over the first dimension of post_samples and prior_samples"
f"should match, but post_samples has {post_variables.shape[0]} and prior_samples has "
f"{prior_variables.shape[0]} elements, respectively."
"The number of elements over the first dimension of estimates and prior_samples"
f"should match, but estimates have {estimates.shape[0]} and prior_samples has "
f"{prior_samples.shape[0]} elements, respectively."
)
elif post_variables.shape[-1] != prior_variables.shape[-1]:
if estimates.shape[-1] != prior_samples.shape[-1]:
raise ShapeError(
"The number of elements over the last dimension of post_samples and prior_samples"
f"should match, but post_samples has {post_variables.shape[1]} and prior_samples has "
f"{prior_variables.shape[-1]} elements, respectively."
"The number of elements over the last dimension of estimates and prior_samples"
f"should match, but estimates has {estimates.shape[0]} and prior_samples has "
f"{prior_samples.shape[0]} elements, respectively."
)


def check_estimates_prior_shapes(estimates: Tensor, prior_samples: Tensor):
"""
Checks requirements for the shapes of estimates and prior_samples draws as
necessitated by most diagnostic functions.
Parameters
----------
estimates : Tensor of shape (num_data_sets, num_post_draws, num_params) or (num_data_sets, num_params)
The model-generated predictions (posterior draws, point estimates) obtained from num_data_sets
prior_samples : Tensor of shape (num_data_sets, num_params)
The prior_samples draws obtained for generating num_data_sets
Raises
------
ShapeError
If there is a deviation form the expected shapes of `estimates` and `estimates`.
"""

check_estimates_shapes(estimates)
check_prior_shapes(prior_samples)
check_consistent_shapes(estimates, prior_samples)

0 comments on commit d990340

Please sign in to comment.