Skip to content

Commit

Permalink
Posterior prediction check
Browse files Browse the repository at this point in the history
  • Loading branch information
voetberg committed Jun 11, 2024
1 parent 77b3b39 commit 0da2035
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_theta_true(self):

def get_sigma_true(self):
if hasattr(self, "sigma_true"):
return self.sigma_true
return self.sigma_true()
else:
return get_item("data", "sigma_true", raise_exception=True)

Expand Down
2 changes: 1 addition & 1 deletion src/data/h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ def get_theta_true(self):

def get_sigma_true(self):
try:
return super().sigma_true()
return super().get_sigma_true()
except (AssertionError, KeyError):
return 1
4 changes: 3 additions & 1 deletion src/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from plots.ranks import Ranks
from plots.tarp import TARP
from plots.local_two_sample import LocalTwoSampleTest
from plots.predictive_posterior_check import PPC

Plots = {
CDFRanks.__name__: CDFRanks,
CoverageFraction.__name__: CoverageFraction,
Ranks.__name__: Ranks,
TARP.__name__: TARP,
"LC2ST": LocalTwoSampleTest
"LC2ST": LocalTwoSampleTest,
PPC.__name__: PPC
}
121 changes: 121 additions & 0 deletions src/plots/predictive_posterior_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from typing import Optional, Sequence
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import KFold

from plots.plot import Display
from utils.plotting_utils import get_hex_colors

class PPC(Display):
    """Posterior predictive check display.

    For a handful of randomly chosen contexts, draws parameters from the
    model posterior, pushes them through the data simulator, and plots the
    spread of the simulated outputs (top row) next to the simulator output
    obtained from the true parameters (bottom row).
    """

    def __init__(
        self,
        model,
        data,
        save: bool,
        show: bool,
        out_dir: Optional[str] = None,
        percentiles: Optional[Sequence] = None,
        use_progress_bar: Optional[bool] = None,
        samples_per_inference: Optional[int] = None,
        number_simulations: Optional[int] = None,
        parameter_names: Optional[Sequence] = None,
        parameter_colors: Optional[Sequence] = None,
        colorway: Optional[str] = None
    ):
        """Forward every argument, unchanged and in order, to ``Display``."""
        super().__init__(
            model,
            data,
            save,
            show,
            out_dir,
            percentiles,
            use_progress_bar,
            samples_per_inference,
            number_simulations,
            parameter_names,
            parameter_colors,
            colorway,
        )

    def _plot_name(self):
        """Return the file name used for the saved figure."""
        return "predictive_posterior_check.png"

    def get_posterior(self, n_simulator_draws):
        """Simulate outputs for ``n_simulator_draws`` randomly chosen contexts.

        Populates three attributes:

        * ``self.context`` -- shape ``(n_simulator_draws, context_dim)``,
          the selected context rows.
        * ``self.posterior_predictive_samples`` -- shape
          ``(n_simulator_draws, samples_per_inference, context_dim)``,
          simulator output for parameters sampled from the model posterior.
        * ``self.posterior_true_samples`` -- same shape, simulator output
          for the true parameters of the selected rows.

        ``context_dim`` is the last axis of ``self.data.true_context()``.

        Args:
            n_simulator_draws: Number of context rows to draw (with
                replacement) using ``self.data.rng``.
        """
        # Fetch the data once instead of on every loop iteration.
        true_context = self.data.true_context()
        theta_true = self.data.get_theta_true()
        context_shape = true_context.shape

        self.posterior_predictive_samples = np.zeros(
            (n_simulator_draws, self.samples_per_inference, context_shape[-1])
        )
        self.posterior_true_samples = np.zeros_like(
            self.posterior_predictive_samples
        )
        self.context = np.zeros((n_simulator_draws, context_shape[-1]))

        random_context_indices = self.data.rng.integers(
            0, context_shape[0], n_simulator_draws
        )
        for index, sample in enumerate(random_context_indices):
            context_sample = true_context[sample, :]
            self.context[index] = context_sample

            posterior_sample = self.model.sample_posterior(
                self.samples_per_inference, context_sample
            )

            # Simulator output for the posterior draws at this context.
            self.posterior_predictive_samples[index] = self.data.simulator.simulate(
                theta=posterior_sample, context_samples=context_sample
            )
            # Simulator output for the true parameters at the same context;
            # assigned into a (samples_per_inference, context_dim) slot.
            self.posterior_true_samples[index] = self.data.simulator.simulate(
                theta=theta_true[sample, :], context_samples=context_sample
            )

    def _plot(
        self,
        n_coverage_sigma: Optional[int] = 3,
        true_sigma: Optional[float] = None,
        theta_true_marker: Optional[str] = '^',
        n_unique_plots: Optional[int] = 3,
        title: str = "Predictive Posterior",
        y_label: str = "Simulation Output",
        x_label: str = "X"
    ):
        """Draw the posterior predictive check figure.

        The figure has two rows and ``n_unique_plots`` columns. The top row
        shows the mean of the posterior-predictive simulations with shaded
        bands at 1..``n_coverage_sigma`` standard deviations and dash-dot
        lines at mean +/- ``true_sigma``. The bottom row scatters the mean
        simulator output obtained from the true parameters.

        Args:
            n_coverage_sigma: Number of standard-deviation bands to shade.
            true_sigma: Offset for the "True Input Error" lines. When None,
                ``self.data.get_sigma_true()`` is used.
            theta_true_marker: Marker style for the bottom-row scatter.
            n_unique_plots: Number of contexts (columns) to plot.
            title: Figure title.
            y_label: Shared y-axis label for the figure.
            x_label: Shared x-axis label for the figure.
        """
        self.get_posterior(n_unique_plots)
        true_sigma = (
            true_sigma if true_sigma is not None else self.data.get_sigma_true()
        )

        figure, subplots = plt.subplots(
            2,
            n_unique_plots,
            figsize=(
                int(self.figure_size[0] * n_unique_plots * .6),
                self.figure_size[1],
            ),
            sharex=False,
            sharey=True,
            # Always return a 2-D axes array so [row, column] indexing also
            # works when only a single column is requested.
            squeeze=False,
        )
        colors = get_hex_colors(n_coverage_sigma, self.colorway)

        for plot_index in range(n_unique_plots):
            predicted_axis = subplots[0, plot_index]
            true_axis = subplots[1, plot_index]
            context_values = self.context[plot_index]

            dimension_y_simulation = self.posterior_predictive_samples[plot_index]
            y_simulation_mean = np.mean(dimension_y_simulation, axis=0).ravel()
            y_simulation_std = np.std(dimension_y_simulation, axis=0).ravel()

            # Bands start at 1 sigma: a 0-sigma band has zero width and
            # would only add a misleading legend entry.
            for sigma, color in zip(range(1, n_coverage_sigma + 1), colors):
                predicted_axis.fill_between(
                    context_values.ravel(),
                    y_simulation_mean - sigma * y_simulation_std,
                    y_simulation_mean + sigma * y_simulation_std,
                    color=color,
                    alpha=0.6,
                    label=rf"Pred. with {sigma} $\sigma$",
                )

            predicted_axis.plot(
                context_values,
                y_simulation_mean - true_sigma,
                color="black",
                linestyle="dashdot",
                label="True Input Error"
            )
            # Upper line is left unlabeled so the legend lists it only once.
            predicted_axis.plot(
                context_values,
                y_simulation_mean + true_sigma,
                color="black",
                linestyle="dashdot",
            )

            true_y = np.mean(
                self.posterior_true_samples[plot_index, :, :], axis=0
            ).ravel()
            true_axis.scatter(
                context_values,
                true_y,
                marker=theta_true_marker,
                label='Theta True'
            )

        # Legends only on the right-most column to avoid repetition.
        subplots[1, -1].legend()
        subplots[0, -1].legend()

        subplots[1, 0].set_ylabel("True Parameters")
        subplots[0, 0].set_ylabel("Predicted Parameters")

        figure.supylabel(y_label)
        figure.supxlabel(x_label)
        figure.suptitle(title)
8 changes: 7 additions & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
Ranks,
CoverageFraction,
TARP,
LocalTwoSampleTest
LocalTwoSampleTest,
PPC
)


@pytest.fixture
Expand Down Expand Up @@ -63,3 +65,7 @@ def test_lc2st(plot_config, mock_model, mock_data):
plot(**get_item("plots", "LC2ST", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}")

def test_ppc(plot_config, mock_model, mock_data):
    """Render the PPC plot with its configured kwargs and check it was saved."""
    ppc_display = PPC(mock_model, mock_data, save=True, show=False)
    plot_kwargs = get_item("plots", "PPC", raise_exception=False)
    ppc_display(**plot_kwargs)

    saved_figure = f"{ppc_display.out_dir}/{ppc_display.plot_name}"
    assert os.path.exists(saved_figure)

0 comments on commit 0da2035

Please sign in to comment.