Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mc_calibration plot #294

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions bayesflow/diagnostics/plots/mc_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,41 +68,42 @@ def mc_calibration(

# Gather plot data and metadata into a dictionary
plot_data = prepare_plot_data(
estimates=pred_models,
ground_truths=true_models,
targets=pred_models,
references=true_models,
variable_names=model_names,
num_col=num_col,
num_row=num_row,
figsize=figsize,
default_name="M",
)

# Compute calibration
cal_errors, true_probs, pred_probs = expected_calibration_error(
plot_data["ground_truths"], plot_data["estimates"], num_bins
plot_data["references"], plot_data["targets"], num_bins
)

for j, ax in enumerate(plot_data["axes"].flat):
# Plot calibration curve
ax[j].plot(pred_probs[j], true_probs[j], "o-", color=color)
ax.plot(pred_probs[j], true_probs[j], "o-", color=color)

# Plot PMP distribution over bins
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
norm_weights = np.ones_like(plot_data["estimates"]) / len(plot_data["estimates"])
ax[j].hist(plot_data["estimates"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)
norm_weights = np.ones_like(plot_data["targets"]) / len(plot_data["targets"])
ax.hist(plot_data["targets"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

# Plot AB line
ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9)
ax.plot((0, 1), (0, 1), "--", color="black", alpha=0.9)

# Tweak plot
ax[j].set_xlim([0 - epsilon, 1 + epsilon])
ax[j].set_ylim([0 - epsilon, 1 + epsilon])
ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_xlim([0 - epsilon, 1 + epsilon])
ax.set_ylim([0 - epsilon, 1 + epsilon])
ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

# Add ECE label
add_metric(
ax[j],
metric_text=r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}",
ax,
metric_text=r"$\widehat{{\mathrm{{ECE}}}}$",
metric_value=cal_errors[j],
metric_fontsize=metric_fontsize,
)
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from .optimal_transport import optimal_transport
from .plot_utils import (
check_posterior_prior_shapes,
check_estimates_prior_shapes,
prepare_plot_data,
add_titles_and_labels,
prettify_subplots,
Expand Down
4 changes: 2 additions & 2 deletions bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import matplotlib.pyplot as plt
import seaborn as sns

from .validators import check_posterior_prior_shapes
from .validators import check_estimates_prior_shapes
from .dict_utils import dicts_to_arrays


Expand Down Expand Up @@ -52,7 +52,7 @@ def prepare_plot_data(
plot_data = dicts_to_arrays(
targets=targets, references=references, variable_names=variable_names, default_name=default_name
)
check_posterior_prior_shapes(plot_data["targets"], plot_data["references"])
check_estimates_prior_shapes(plot_data["targets"], plot_data["references"])

# Configure layout
num_row, num_col = set_layout(plot_data["num_variables"], num_row, num_col, stacked)
Expand Down
96 changes: 65 additions & 31 deletions bayesflow/utils/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,82 @@ def check_lengths_same(*args):
raise ValueError(f"All tuple arguments must have the same length, but lengths are {tuple(map(len, args))}.")


def check_posterior_prior_shapes(post_variables: Tensor, prior_variables: Tensor):
def check_prior_shapes(variables: Tensor):
    """
    Checks the shape of prior draws as required by most diagnostic functions.

    Parameters
    ----------
    variables : Tensor of shape (num_data_sets, num_params)
        The prior draws (prior_samples) used for generating num_data_sets

    Raises
    ------
    ShapeError
        If `variables` is not a 2-dimensional array.
    """

    # Only the rank is validated here; the sizes of the individual
    # dimensions are not inspected by this function.
    num_dims = len(variables.shape)
    if num_dims != 2:
        raise ShapeError(
            "prior_samples should be a 2-dimensional array, with the "
            "first dimension being the number of (simulated) data sets / prior_samples draws "
            "and the second dimension being the number of variables, "
            f"but your input has dimensions {num_dims}"
        )
elif len(prior_variables.shape) != 2:


def check_estimates_shapes(variables: Tensor):
    """
    Checks the shape of model-generated predictions (posterior draws, point estimates)
    as required by most diagnostic functions.

    Parameters
    ----------
    variables : Tensor of shape (num_data_sets, num_post_draws, num_params) or (num_data_sets, num_params)
        The model-generated predictions (posterior draws, point estimates) obtained from num_data_sets

    Raises
    ------
    ShapeError
        If `variables` is neither a 2-dimensional nor a 3-dimensional array.
    """
    # Both ranks are accepted: 3-D input carries a draws dimension in the
    # middle, 2-D input has no draws dimension.
    num_dims = len(variables.shape)
    if num_dims not in (2, 3):
        raise ShapeError(
            "estimates should be a 2- or 3-dimensional array, with the "
            "first dimension being the number of data sets, "
            "(optional) second dimension the number of posterior draws per data set, "
            "and the last dimension the number of estimated variables, "
            f"but your input has dimensions {num_dims}"
        )
elif post_variables.shape[0] != prior_variables.shape[0]:


def check_consistent_shapes(estimates: Tensor, prior_samples: Tensor):
    """
    Checks whether the model-generated predictions (posterior draws, point estimates) and
    prior_samples have consistent leading (num_data_sets) and trailing (num_params) dimensions.

    Parameters
    ----------
    estimates : Tensor
        The model-generated predictions; only its first and last dimensions are compared.
    prior_samples : Tensor
        The prior draws; only its first and last dimensions are compared.

    Raises
    ------
    ShapeError
        If the first dimensions or the last dimensions of the two inputs differ.
    """
    if estimates.shape[0] != prior_samples.shape[0]:
        # Adjacent string literals are concatenated verbatim, so the trailing
        # space after "prior_samples" is required to keep the words apart.
        raise ShapeError(
            "The number of elements over the first dimension of estimates and prior_samples "
            f"should match, but estimates has {estimates.shape[0]} and prior_samples has "
            f"{prior_samples.shape[0]} elements, respectively."
        )
    if estimates.shape[-1] != prior_samples.shape[-1]:
        # Report the sizes of the dimension that was actually compared (the last one).
        raise ShapeError(
            "The number of elements over the last dimension of estimates and prior_samples "
            f"should match, but estimates has {estimates.shape[-1]} and prior_samples has "
            f"{prior_samples.shape[-1]} elements, respectively."
        )


def check_estimates_prior_shapes(estimates: Tensor, prior_samples: Tensor):
    """
    Checks requirements for the shapes of estimates and prior_samples draws as
    necessitated by most diagnostic functions.

    Parameters
    ----------
    estimates : Tensor of shape (num_data_sets, num_post_draws, num_params) or (num_data_sets, num_params)
        The model-generated predictions (posterior draws, point estimates) obtained from num_data_sets
    prior_samples : Tensor of shape (num_data_sets, num_params)
        The prior_samples draws obtained for generating num_data_sets

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of `estimates` and `prior_samples`.
    """

    # Validate each input on its own first, then compare them with each other.
    check_estimates_shapes(estimates)
    check_prior_shapes(prior_samples)
    check_consistent_shapes(estimates, prior_samples)
Loading