Merge pull request #74 from voetberg/test_2d
2D Test suites and corrections to existing code
voetberg authored Jun 18, 2024
2 parents 4e21245 + c05ebdc commit 5cf472c
Showing 10 changed files with 1,770 additions and 1,153 deletions.
2,575 changes: 1,510 additions & 1,065 deletions poetry.lock

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions src/data/data.py
@@ -1,6 +1,4 @@
-import importlib.util
-import sys
-import os
from typing import Optional
import numpy as np

from utils.config import get_item
@@ -14,6 +12,7 @@ def __init__(
simulator_kwargs: dict = None,
prior: str = None,
prior_kwargs: dict = None,
simulation_dimensions:Optional[int] = None,
):
self.rng = np.random.default_rng(
get_item("common", "random_seed", raise_exception=False)
@@ -22,6 +21,12 @@ def __init__(
self.simulator = load_simulator(simulator_name, simulator_kwargs)
self.prior_dist = self.load_prior(prior, prior_kwargs)
self.n_dims = self.get_theta_true().shape[1]
self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False)

def get_simulator_output_shape(self):
context_shape = self.true_context().shape
sim_out = self.simulator(theta=self.get_theta_true()[0:1, :], n_samples=context_shape[-1])
return sim_out.shape

def _load(self, path: str):
raise NotImplementedError
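The new `get_simulator_output_shape` probes the output shape by running the simulator once on the first row of the true parameters, while `simulation_dimensions` falls back to the `data.simulator_dimensions` config entry (default 1, see `src/utils/defaults.py` below). A rough standalone sketch of the same probe; the toy simulator, `theta_true`, and `context_size` are made-up stand-ins, not part of the repository:

```python
import numpy as np

class ToyImageSimulator:
    """Hypothetical simulator whose output is a 2D image per parameter draw."""
    def __call__(self, theta, n_samples):
        rng = np.random.default_rng(0)
        # Shape (1, n_samples, n_samples): one image for the single theta row.
        return rng.normal(loc=theta.mean(), size=(1, n_samples, n_samples))

simulator = ToyImageSimulator()
theta_true = np.array([[0.5, 1.2]])   # stand-in for get_theta_true()[0:1, :]
context_size = 32                     # stand-in for true_context().shape[-1]

sim_out_shape = simulator(theta=theta_true, n_samples=context_size).shape
print(sim_out_shape)                  # (1, 32, 32); downstream code drops the leading 1
```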
13 changes: 10 additions & 3 deletions src/data/h5_data.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable
from typing import Any, Callable, Optional
import h5py
import numpy as np
import torch
@@ -8,8 +8,15 @@


class H5Data(Data):
-    def __init__(self, path: str, simulator: Callable):
-        super().__init__(path, simulator)
def __init__(self,
path: str,
simulator: Callable,
simulator_kwargs: dict = None,
prior: str = None,
prior_kwargs: dict = None,
simulation_dimensions:Optional[int] = None,
):
super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions)

def _load(self, path):
assert path.split(".")[-1] == "h5", "File extension must be h5"
13 changes: 10 additions & 3 deletions src/data/pickle_data.py
@@ -1,11 +1,18 @@
import pickle
-from typing import Any, Callable
from typing import Any, Callable, Optional
from data.data import Data


class PickleData(Data):
-    def __init__(self, path: str, simulator: Callable):
-        super().__init__(path, simulator)
def __init__(self,
path: str,
simulator: Callable,
simulator_kwargs: dict = None,
prior: str = None,
prior_kwargs: dict = None,
simulation_dimensions:Optional[int] = None,
):
super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions)

def _load(self, path: str):
assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'"
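Both file-backed loaders now accept and forward the full `Data` signature instead of only `path` and `simulator`. A sketch of what a call site could look like; the file path, the simulator name, and the assumptions that `src/` is importable and that `my_sim` was previously registered via `register_simulator` are all illustrative, not taken from the repository:

```python
from data.pickle_data import PickleData

# Hypothetical usage; "validation_set.pkl" and "my_sim" are placeholders.
data = PickleData(
    path="validation_set.pkl",
    simulator="my_sim",              # name of a previously registered simulator
    prior="normal",
    simulation_dimensions=2,         # new keyword threaded through to Data.__init__
)
```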
68 changes: 50 additions & 18 deletions src/metrics/local_two_sample.py
@@ -37,29 +37,43 @@ def _collect_data_params(self):
# P is the prior and x_P is generated via the simulator from the parameters P.
self.p = self.data.sample_prior(self.number_simulations)
self.q = np.zeros_like(self.p)

context_size = self.data.true_context().shape[-1]
-        self.outcome_given_p = np.zeros(
-            (self.number_simulations, context_size)
-        )
        remove_first_dim = False

        if self.data.simulator_dimensions == 1:
            self.outcome_given_p = np.zeros((self.number_simulations, context_size))
elif self.data.simulator_dimensions == 2:
sim_out_shape = self.data.get_simulator_output_shape()
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True

sim_out_shape = np.product(sim_out_shape)
self.outcome_given_p = np.zeros((self.number_simulations, sim_out_shape))
else:
raise NotImplementedError("LC2ST only implemented for 1 or two dimensions.")

self.outcome_given_q = np.zeros_like(self.outcome_given_p)
-        self.evaluation_context = np.zeros_like(self.outcome_given_p)
self.evaluation_context = np.zeros((self.number_simulations, context_size))

for index, p in enumerate(self.p):
context = self.data.simulator.generate_context(context_size)
-            self.outcome_given_p[index] = self.data.simulator.simulate(p, context)
-            # Q is the approximate posterior amortized in x
q = self.model.sample_posterior(1, context).ravel()
self.q[index] = q
-            self.outcome_given_q[index] = self.data.simulator.simulate(q, context)
-            self.evaluation_context[index] = context

p_outcome = self.data.simulator.simulate(p, context)
q_outcome = self.data.simulator.simulate(q, context)

if remove_first_dim:
p_outcome = p_outcome[0]
q_outcome = q_outcome[0]

self.outcome_given_p[index] = p_outcome.ravel()
self.outcome_given_q[index] = q_outcome.ravel() # Q is the approximate posterior amortized in x


self.evaluation_context = np.array(
[
self.data.simulator.generate_context(context_size)
for _ in range(self.num_simulations)
]
)
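The loop above is what makes 2D simulators fit the LC2ST classifiers: each (1, H, W) draw loses its spurious leading axis via `remove_first_dim` and is then raveled into a flat feature row of length H*W. A standalone sketch of that flattening, with a made-up 28x28 output (nothing here comes from the repository):

```python
import numpy as np

# Hypothetical single simulator draw with a leading batch axis of size 1.
sim_output = np.random.default_rng(0).normal(size=(1, 28, 28))

sim_out_shape = sim_output.shape
remove_first_dim = False
if len(sim_out_shape) != 2:
    sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
    remove_first_dim = True

n_features = int(np.prod(sim_out_shape))                 # 28 * 28 = 784
row = (sim_output[0] if remove_first_dim else sim_output).ravel()
assert row.shape == (n_features,)
```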

def train_linear_classifier(
self, p, q, x_p, x_q, classifier: str, classifier_kwargs: dict = {}
@@ -127,9 +141,21 @@ def _cross_eval_score(
cv_splits = kf.split(p)
# train classifiers over cv-folds
probabilities = []
-        self.evaluation_data = np.zeros(
-            (n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1])
-        )

remove_first_dim = False
if self.data.simulator_dimensions == 1:
self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1]))

elif self.data.simulator_dimensions == 2:
sim_out_shape = self.data.get_simulator_output_shape()
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True

sim_out_shape = np.product(sim_out_shape)
self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), sim_out_shape))

self.prior_evaluation = np.zeros_like(p)

kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42)
@@ -142,10 +168,16 @@
p_train, q_train, x_p_train, x_q_train, classifier, classifier_kwargs
)
p_evaluate = p[val_index]

for index, p_validation in enumerate(p_evaluate):
-                self.evaluation_data[cross_trial][index] = self.data.simulator.simulate(
sim_output = self.data.simulator.simulate(
p_validation, self.evaluation_context[val_index][index]
)

if remove_first_dim:
sim_output = sim_output[0]
self.evaluation_data[cross_trial][index] = sim_output.ravel()

self.prior_evaluation[index] = p_validation
probabilities.append(
self._eval_model(
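For the cross-evaluation above, the buffer holds one `(fold_size, n_features)` block of flattened simulator outputs per fold, with `fold_size` read off the first split. A self-contained sketch of that sizing; the sample counts and image shape are invented for illustration:

```python
import numpy as np
from sklearn.model_selection import KFold

rng = np.random.default_rng(0)
n_samples, image_shape, n_cross_folds = 100, (28, 28), 5
n_features = int(np.prod(image_shape))              # flattened 2D output length

p = rng.normal(size=(n_samples, 2))                 # stand-in parameter draws
kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42)

fold_size = len(next(kf.split(p))[1])               # 100 samples / 5 folds -> 20
evaluation_data = np.zeros((n_cross_folds, fold_size, n_features))

for cross_trial, (train_index, val_index) in enumerate(kf.split(p)):
    for index, p_validation in enumerate(p[val_index]):
        fake_image = rng.normal(size=image_shape)   # stand-in for simulator.simulate(...)
        evaluation_data[cross_trial][index] = fake_image.ravel()
```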
145 changes: 102 additions & 43 deletions src/plots/predictive_posterior_check.py
@@ -27,9 +27,41 @@ def __init__(
def _plot_name(self):
return "predictive_posterior_check.png"

-    def get_posterior(self, n_simulator_draws):
def get_posterior_2d(self, n_simulator_draws):
context_shape = self.data.true_context().shape
-        self.posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference,context_shape[-1]))
sim_out_shape = self.data.get_simulator_output_shape()
remove_first_dim = False
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True

self.posterior_predictive_samples = np.zeros((n_simulator_draws, *sim_out_shape))
self.posterior_true_samples = np.zeros_like(self.posterior_predictive_samples)

random_context_indices = self.data.rng.integers(0, context_shape[0], n_simulator_draws)
for index, sample in enumerate(random_context_indices):
context_sample = self.data.true_context()[sample, :]
posterior_sample = self.model.sample_posterior(1, context_sample)

# get the posterior samples for that context
sim_out_posterior = self.data.simulator.simulate(
theta=posterior_sample, context_samples = context_sample
)
sim_out_true = self.data.simulator.simulate(
theta=self.data.get_theta_true()[sample, :], context_samples=context_sample
)
if remove_first_dim:
sim_out_posterior = sim_out_posterior[0]
sim_out_true = sim_out_true[0]

self.posterior_predictive_samples[index] = sim_out_posterior
self.posterior_true_samples[index] = sim_out_true


def get_posterior_1d(self, n_simulator_draws):
context_shape = self.data.true_context().shape
self.posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference, context_shape[-1]))
self.posterior_true_samples = np.zeros_like(self.posterior_predictive_samples)
self.context = np.zeros((n_simulator_draws, context_shape[-1]))

@@ -48,70 +80,97 @@ def get_posterior(self, n_simulator_draws):
theta=self.data.get_theta_true()[sample, :], context_samples=context_sample
)

def _plot_1d(self,
subplots: np.ndarray,
subplot_index: int,
n_coverage_sigma: Optional[int] = 3,
theta_true_marker: Optional[str] = '^'
):

dimension_y_simulation = self.posterior_predictive_samples[subplot_index]
y_simulation_mean = np.mean(dimension_y_simulation, axis=0).ravel()
y_simulation_std = np.std(dimension_y_simulation, axis=0).ravel()

for sigma, color in zip(range(n_coverage_sigma), self.colors):
subplots[0, subplot_index].fill_between(
self.context[subplot_index].ravel(),
y_simulation_mean - sigma * y_simulation_std,
y_simulation_mean + sigma * y_simulation_std,
color=color,
alpha=0.6,
label=rf"Pred. with {sigma} $\sigma$",
)

subplots[0, subplot_index].plot(
self.context[subplot_index],
y_simulation_mean - self.true_sigma,
color="black",
linestyle="dashdot",
label="True Input Error"
)
subplots[0, subplot_index].plot(
self.context[subplot_index],
y_simulation_mean + self.true_sigma,
color="black",
linestyle="dashdot",
)

true_y = np.mean(self.posterior_true_samples[subplot_index, :, :], axis=0).ravel()
subplots[1, subplot_index].scatter(
self.context[subplot_index],
true_y,
marker=theta_true_marker,
label='Theta True'
)

def _plot_2d(self, subplots, subplot_index, include_axis_ticks):
subplots[1, subplot_index].imshow(self.posterior_predictive_samples[subplot_index])
subplots[0, subplot_index].imshow(self.posterior_true_samples[subplot_index])

if not include_axis_ticks:
subplots[1, subplot_index].set_xticks([])
subplots[1, subplot_index].set_yticks([])

subplots[0, subplot_index].set_xticks([])
subplots[0, subplot_index].set_yticks([])

def _plot(
self,
n_coverage_sigma: Optional[int] = 3,
true_sigma: Optional[float] = None,
theta_true_marker: Optional[str] = '^',
n_unique_plots: Optional[int] = 3,
include_axis_ticks: bool = False,
title:str="Predictive Posterior",
y_label:str="Simulation Output",
x_label:str="X"):

if self.data.simulator_dimensions == 1:
self.get_posterior_1d(n_unique_plots)
self.true_sigma = true_sigma if true_sigma is not None else self.data.get_sigma_true()
self.colors = get_hex_colors(n_coverage_sigma, self.colorway)

-        self.get_posterior(n_unique_plots)
-        true_sigma = true_sigma if true_sigma is not None else self.data.get_sigma_true()
elif self.data.simulator_dimensions == 2:
self.get_posterior_2d(n_unique_plots)

else:
raise NotImplementedError("Posterior Checks only implemented for 1 or two dimensions.")

figure, subplots = plt.subplots(
2,
n_unique_plots,
figsize=(int(self.figure_size[0]*n_unique_plots*.6), self.figure_size[1]),
sharex=False,
sharey=True
)
-        colors = get_hex_colors(n_coverage_sigma, self.colorway)

for plot_index in range(n_unique_plots):
if self.data.simulator_dimensions == 1:
self._plot_1d(subplots, plot_index, n_coverage_sigma, theta_true_marker)

-            dimension_y_simulation = self.posterior_predictive_samples[plot_index]
-
-            y_simulation_mean = np.mean(dimension_y_simulation, axis=0).ravel()
-            y_simulation_std = np.std(dimension_y_simulation, axis=0).ravel()
-
-            for sigma, color in zip(range(n_coverage_sigma), colors):
-                subplots[0, plot_index].fill_between(
-                    self.context[plot_index].ravel(),
-                    y_simulation_mean - sigma * y_simulation_std,
-                    y_simulation_mean + sigma * y_simulation_std,
-                    color=color,
-                    alpha=0.6,
-                    label=rf"Pred. with {sigma} $\sigma$",
-                )
-
-            subplots[0, plot_index].plot(
-                self.context[plot_index],
-                y_simulation_mean - true_sigma,
-                color="black",
-                linestyle="dashdot",
-                label="True Input Error"
-            )
-            subplots[0, plot_index].plot(
-                self.context[plot_index],
-                y_simulation_mean + true_sigma,
-                color="black",
-                linestyle="dashdot",
-            )
-
-            true_y = np.mean(self.posterior_true_samples[plot_index, :, :], axis=0).ravel()
-            subplots[1, plot_index].scatter(
-                self.context[plot_index],
-                true_y,
-                marker=theta_true_marker,
-                label='Theta True'
-            )
else:
self._plot_2d(subplots, plot_index, include_axis_ticks)

subplots[1, -1].legend()
subplots[0, -1].legend()

subplots[1, 0].set_ylabel("True Parameters")
subplots[0, 0].set_ylabel("Predicted Parameters")
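In the 2D branch, `_plot_2d` shows each posterior-predictive draw and its matching true-parameter simulation as images rather than 1D bands. A minimal sketch of that two-row `imshow` grid; the arrays are random stand-ins for `posterior_true_samples` and `posterior_predictive_samples`:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
n_draws = 3
true_imgs = rng.normal(size=(n_draws, 32, 32))                         # stand-in true-parameter simulations
pred_imgs = true_imgs + rng.normal(scale=0.1, size=(n_draws, 32, 32))  # stand-in posterior-draw simulations

fig, subplots = plt.subplots(2, n_draws, sharex=False, sharey=True)
for idx in range(n_draws):
    subplots[0, idx].imshow(true_imgs[idx])    # mirrors _plot_2d's first row
    subplots[1, idx].imshow(pred_imgs[idx])    # and its second row
    for row in (0, 1):                         # include_axis_ticks=False case
        subplots[row, idx].set_xticks([])
        subplots[row, idx].set_yticks([])
fig.savefig("ppc_2d_sketch.png")
```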
4 changes: 3 additions & 1 deletion src/utils/defaults.py
@@ -11,6 +11,7 @@
"prior": "normal",
"prior_kwargs": None,
"simulator_kwargs": None,
"simulator_dimensions": 1,
},
"plots_common": {
"axis_spines": False,
@@ -29,7 +30,8 @@
"TARP": {
"coverage_sigma": 3 # How many sigma to show coverage over
},
"LC2ST": {}
"LC2ST": {},
"PPC":{}
},
"metrics_common": {
"use_progress_bar": False,
6 changes: 4 additions & 2 deletions src/utils/register.py
@@ -18,12 +18,14 @@ def register_simulator(simulator_name, simulator):
if not os.path.exists(sim_paths):
open(sim_paths, 'a').close()

with open(sim_paths, "w+") as f:
with open(sim_paths, "r+") as f:
try:
existing_sims = json.load(f)
except json.decoder.JSONDecodeError:
existing_sims = {}


-        existing_sims[simulator_name] = simulator_location
with open(sim_paths, "w") as f:
existing_sims[simulator_name] = simulator_location
json.dump(existing_sims, f)

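The mode change is the substantive fix here: `open(path, "w+")` truncates the file before `json.load` ever runs, so every registration started from an empty registry; `"r+"` keeps the existing contents readable, and the rewrite happens in a separate `"w"` open. A self-contained sketch of that read-modify-write cycle (the path and the registered entry are placeholders):

```python
import json
import os

sim_paths = "simulators.json"            # placeholder registry path

if not os.path.exists(sim_paths):
    open(sim_paths, "a").close()

# "r+" preserves existing contents for json.load; "w+" would truncate on open,
# so previously registered simulators would be lost every time.
with open(sim_paths, "r+") as f:
    try:
        existing_sims = json.load(f)
    except json.decoder.JSONDecodeError:
        existing_sims = {}

with open(sim_paths, "w") as f:
    existing_sims["my_simulator"] = "/path/to/my_simulator.py"   # hypothetical entry
    json.dump(existing_sims, f)
```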