Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2D Test suits and corrections to existing code #74

Merged
merged 3 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,575 changes: 1,510 additions & 1,065 deletions poetry.lock

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions src/data/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import importlib.util
import sys
import os
from typing import Optional
import numpy as np

from utils.config import get_item
Expand All @@ -14,6 +12,7 @@ def __init__(
simulator_kwargs: dict = None,
prior: str = None,
prior_kwargs: dict = None,
simulation_dimensions:Optional[int] = None,
):
self.rng = np.random.default_rng(
get_item("common", "random_seed", raise_exception=False)
Expand All @@ -22,6 +21,12 @@ def __init__(
self.simulator = load_simulator(simulator_name, simulator_kwargs)
self.prior_dist = self.load_prior(prior, prior_kwargs)
self.n_dims = self.get_theta_true().shape[1]
self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False)

def get_simulator_output_shape(self):
context_shape = self.true_context().shape
sim_out = self.simulator(theta=self.get_theta_true()[0:1, :], n_samples=context_shape[-1])
return sim_out.shape

def _load(self, path: str):
raise NotImplementedError
Expand Down
13 changes: 10 additions & 3 deletions src/data/h5_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable
from typing import Any, Callable, Optional
import h5py
import numpy as np
import torch
Expand All @@ -8,8 +8,15 @@


class H5Data(Data):
def __init__(self, path: str, simulator: Callable):
super().__init__(path, simulator)
def __init__(self,
path: str,
simulator: Callable,
simulator_kwargs: dict = None,
prior: str = None,
prior_kwargs: dict = None,
simulation_dimensions:Optional[int] = None,
):
super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions)

def _load(self, path):
assert path.split(".")[-1] == "h5", "File extension must be h5"
Expand Down
13 changes: 10 additions & 3 deletions src/data/pickle_data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import pickle
from typing import Any, Callable
from typing import Any, Callable, Optional
from data.data import Data


class PickleData(Data):
def __init__(self, path: str, simulator: Callable):
super().__init__(path, simulator)
def __init__(self,
path: str,
simulator: Callable,
simulator_kwargs: dict = None,
prior: str = None,
prior_kwargs: dict = None,
simulation_dimensions:Optional[int] = None,
):
super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions)

def _load(self, path: str):
assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'"
Expand Down
68 changes: 50 additions & 18 deletions src/metrics/local_two_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,43 @@ def _collect_data_params(self):
# P is the prior and x_P is generated via the simulator from the parameters P.
self.p = self.data.sample_prior(self.number_simulations)
self.q = np.zeros_like(self.p)

context_size = self.data.true_context().shape[-1]
self.outcome_given_p = np.zeros(
(self.number_simulations, context_size)
remove_first_dim = False

)
if self.data.simulator_dimensions == 1:
self.outcome_given_p = np.zeros((self.number_simulations, context_size))
elif self.data.simulator_dimensions == 2:
sim_out_shape = self.data.get_simulator_output_shape()
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True

sim_out_shape = np.product(sim_out_shape)
self.outcome_given_p = np.zeros((self.number_simulations, sim_out_shape))
else:
raise NotImplementedError("LC2ST only implemented for 1 or two dimensions.")

self.outcome_given_q = np.zeros_like(self.outcome_given_p)
self.evaluation_context = np.zeros_like(self.outcome_given_p)
self.evaluation_context = np.zeros((self.number_simulations, context_size))

for index, p in enumerate(self.p):
context = self.data.simulator.generate_context(context_size)
self.outcome_given_p[index] = self.data.simulator.simulate(p, context)
# Q is the approximate posterior amortized in x
q = self.model.sample_posterior(1, context).ravel()
self.q[index] = q
self.outcome_given_q[index] = self.data.simulator.simulate(q, context)
self.evaluation_context[index] = context

p_outcome = self.data.simulator.simulate(p, context)
q_outcome = self.data.simulator.simulate(q, context)

if remove_first_dim:
p_outcome = p_outcome[0]
q_outcome = q_outcome[0]

self.outcome_given_p[index] = p_outcome.ravel()
self.outcome_given_q[index] = q_outcome.ravel() # Q is the approximate posterior amortized in x


self.evaluation_context = np.array(
[
self.data.simulator.generate_context(context_size)
for _ in range(self.num_simulations)
]
)

def train_linear_classifier(
self, p, q, x_p, x_q, classifier: str, classifier_kwargs: dict = {}
Expand Down Expand Up @@ -127,9 +141,21 @@ def _cross_eval_score(
cv_splits = kf.split(p)
# train classifiers over cv-folds
probabilities = []
self.evaluation_data = np.zeros(
(n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1])
)

remove_first_dim = False
if self.data.simulator_dimensions == 1:
self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1]))

elif self.data.simulator_dimensions == 2:
sim_out_shape = self.data.get_simulator_output_shape()
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True

sim_out_shape = np.product(sim_out_shape)
self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), sim_out_shape))

self.prior_evaluation = np.zeros_like(p)

kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42)
Expand All @@ -142,10 +168,16 @@ def _cross_eval_score(
p_train, q_train, x_p_train, x_q_train, classifier, classifier_kwargs
)
p_evaluate = p[val_index]

for index, p_validation in enumerate(p_evaluate):
self.evaluation_data[cross_trial][index] = self.data.simulator.simulate(
sim_output = self.data.simulator.simulate(
p_validation, self.evaluation_context[val_index][index]
)

if remove_first_dim:
sim_output = sim_output[0]
self.evaluation_data[cross_trial][index] = sim_output.ravel()

self.prior_evaluation[index] = p_validation
probabilities.append(
self._eval_model(
Expand Down
145 changes: 102 additions & 43 deletions src/plots/predictive_posterior_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,41 @@ def __init__(
def _plot_name(self):
return "predictive_posterior_check.png"

def get_posterior(self, n_simulator_draws):
def get_posterior_2d(self, n_simulator_draws):
context_shape = self.data.true_context().shape
self.posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference,context_shape[-1]))
sim_out_shape = self.data.get_simulator_output_shape()
remove_first_dim = False
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True

self.posterior_predictive_samples = np.zeros((n_simulator_draws, *sim_out_shape))
self.posterior_true_samples = np.zeros_like(self.posterior_predictive_samples)

random_context_indices = self.data.rng.integers(0, context_shape[0], n_simulator_draws)
for index, sample in enumerate(random_context_indices):
context_sample = self.data.true_context()[sample, :]
posterior_sample = self.model.sample_posterior(1, context_sample)

# get the posterior samples for that context
sim_out_posterior = self.data.simulator.simulate(
theta=posterior_sample, context_samples = context_sample
)
sim_out_true = self.data.simulator.simulate(
theta=self.data.get_theta_true()[sample, :], context_samples=context_sample
)
if remove_first_dim:
sim_out_posterior = sim_out_posterior[0]
sim_out_true = sim_out_true[0]

self.posterior_predictive_samples[index] = sim_out_posterior
self.posterior_true_samples[index] = sim_out_true


def get_posterior_1d(self, n_simulator_draws):
context_shape = self.data.true_context().shape
self.posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference, context_shape[-1]))
self.posterior_true_samples = np.zeros_like(self.posterior_predictive_samples)
self.context = np.zeros((n_simulator_draws, context_shape[-1]))

Expand All @@ -48,70 +80,97 @@ def get_posterior(self, n_simulator_draws):
theta=self.data.get_theta_true()[sample, :], context_samples=context_sample
)

def _plot_1d(self,
subplots: np.ndarray,
subplot_index: int,
n_coverage_sigma: Optional[int] = 3,
theta_true_marker: Optional[str] = '^'
):

dimension_y_simulation = self.posterior_predictive_samples[subplot_index]
y_simulation_mean = np.mean(dimension_y_simulation, axis=0).ravel()
y_simulation_std = np.std(dimension_y_simulation, axis=0).ravel()

for sigma, color in zip(range(n_coverage_sigma), self.colors):
subplots[0, subplot_index].fill_between(
self.context[subplot_index].ravel(),
y_simulation_mean - sigma * y_simulation_std,
y_simulation_mean + sigma * y_simulation_std,
color=color,
alpha=0.6,
label=rf"Pred. with {sigma} $\sigma$",
)

subplots[0, subplot_index].plot(
self.context[subplot_index],
y_simulation_mean - self.true_sigma,
color="black",
linestyle="dashdot",
label="True Input Error"
)
subplots[0, subplot_index].plot(
self.context[subplot_index],
y_simulation_mean + self.true_sigma,
color="black",
linestyle="dashdot",
)

true_y = np.mean(self.posterior_true_samples[subplot_index, :, :], axis=0).ravel()
subplots[1, subplot_index].scatter(
self.context[subplot_index],
true_y,
marker=theta_true_marker,
label='Theta True'
)

def _plot_2d(self, subplots, subplot_index, include_axis_ticks):
subplots[1, subplot_index].imshow(self.posterior_predictive_samples[subplot_index])
subplots[0, subplot_index].imshow(self.posterior_true_samples[subplot_index])

if not include_axis_ticks:
subplots[1, subplot_index].set_xticks([])
subplots[1, subplot_index].set_yticks([])

subplots[0, subplot_index].set_xticks([])
subplots[0, subplot_index].set_yticks([])

def _plot(
self,
n_coverage_sigma: Optional[int] = 3,
true_sigma: Optional[float] = None,
theta_true_marker: Optional[str] = '^',
n_unique_plots: Optional[int] = 3,
include_axis_ticks: bool = False,
title:str="Predictive Posterior",
y_label:str="Simulation Output",
x_label:str="X"):

if self.data.simulator_dimensions == 1:
self.get_posterior_1d(n_unique_plots)
self.true_sigma = true_sigma if true_sigma is not None else self.data.get_sigma_true()
self.colors = get_hex_colors(n_coverage_sigma, self.colorway)

self.get_posterior(n_unique_plots)
true_sigma = true_sigma if true_sigma is not None else self.data.get_sigma_true()
elif self.data.simulator_dimensions == 2:
self.get_posterior_2d(n_unique_plots)

else:
raise NotImplementedError("Posterior Checks only implemented for 1 or two dimensions.")

figure, subplots = plt.subplots(
2,
n_unique_plots,
figsize=(int(self.figure_size[0]*n_unique_plots*.6), self.figure_size[1]),
sharex=False,
sharey=True
)
colors = get_hex_colors(n_coverage_sigma, self.colorway)

for plot_index in range(n_unique_plots):
if self.data.simulator_dimensions == 1:
self._plot_1d(subplots, plot_index, n_coverage_sigma, theta_true_marker)

dimension_y_simulation = self.posterior_predictive_samples[plot_index]

y_simulation_mean = np.mean(dimension_y_simulation, axis=0).ravel()
y_simulation_std = np.std(dimension_y_simulation, axis=0).ravel()

for sigma, color in zip(range(n_coverage_sigma), colors):
subplots[0, plot_index].fill_between(
self.context[plot_index].ravel(),
y_simulation_mean - sigma * y_simulation_std,
y_simulation_mean + sigma * y_simulation_std,
color=color,
alpha=0.6,
label=rf"Pred. with {sigma} $\sigma$",
)

subplots[0, plot_index].plot(
self.context[plot_index],
y_simulation_mean - true_sigma,
color="black",
linestyle="dashdot",
label="True Input Error"
)
subplots[0, plot_index].plot(
self.context[plot_index],
y_simulation_mean + true_sigma,
color="black",
linestyle="dashdot",
)

true_y = np.mean(self.posterior_true_samples[plot_index, :, :], axis=0).ravel()
subplots[1, plot_index].scatter(
self.context[plot_index],
true_y,
marker=theta_true_marker,
label='Theta True'
)
else:
self._plot_2d(subplots, plot_index, include_axis_ticks)

subplots[1, -1].legend()
subplots[0, -1].legend()

subplots[1, 0].set_ylabel("True Parameters")
subplots[0, 0].set_ylabel("Predicted Parameters")
Expand Down
4 changes: 3 additions & 1 deletion src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"prior": "normal",
"prior_kwargs": None,
"simulator_kwargs": None,
"simulator_dimensions": 1,
},
"plots_common": {
"axis_spines": False,
Expand All @@ -29,7 +30,8 @@
"TARP": {
"coverage_sigma": 3 # How many sigma to show coverage over
},
"LC2ST": {}
"LC2ST": {},
"PPC":{}
},
"metrics_common": {
"use_progress_bar": False,
Expand Down
6 changes: 4 additions & 2 deletions src/utils/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ def register_simulator(simulator_name, simulator):
if not os.path.exists(sim_paths):
open(sim_paths, 'a').close()

with open(sim_paths, "w+") as f:
with open(sim_paths, "r+") as f:
try:
existing_sims = json.load(f)
except json.decoder.JSONDecodeError:
existing_sims = {}


existing_sims[simulator_name] = simulator_location
with open(sim_paths, "w") as f:
existing_sims[simulator_name] = simulator_location
json.dump(existing_sims, f)

Expand Down
Loading