Skip to content

Commit

Permalink
Merge pull request #81 from deepskies/simulation_warning
Browse files Browse the repository at this point in the history
Add warning that simulator is missing
  • Loading branch information
voetberg authored Jun 27, 2024
2 parents 4fd4f48 + d30587f commit f489e74
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 12 deletions.
15 changes: 11 additions & 4 deletions src/deepdiagnostics/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from deepdiagnostics.models import ModelModules
from deepdiagnostics.metrics import Metrics
from deepdiagnostics.plots import Plots
from deepdiagnostics.utils.simulator_utils import SimulatorMissingError


def parser():
Expand Down Expand Up @@ -109,9 +110,15 @@ def main():
plots = config.get_section("plots", raise_exception=False)

for metrics_name, metrics_args in metrics.items():
Metrics[metrics_name](model, data, save=True)(**metrics_args)
try:
Metrics[metrics_name](model, data, save=True)(**metrics_args)
except SimulatorMissingError:
print(f"Cannot run {metrics_name} - simulator missing.")

for plot_name, plot_args in plots.items():
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
**plot_args
)
try:
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
**plot_args
)
except SimulatorMissingError:
print(f"Cannot run {plot_name} - simulator missing.")
8 changes: 6 additions & 2 deletions src/deepdiagnostics/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from deepdiagnostics.utils.config import get_item
from deepdiagnostics.utils.register import load_simulator
from deepdiagnostics.utils.simulator_utils import load_simulator

class Data:
"""
Expand Down Expand Up @@ -35,7 +35,11 @@ def __init__(
get_item("common", "random_seed", raise_exception=False)
)
self.data = self._load(path)
self.simulator = load_simulator(simulator_name, simulator_kwargs)
try:
self.simulator = load_simulator(simulator_name, simulator_kwargs)
except RuntimeError:
print("Warning: Simulator not loaded. Can only run non-generative metrics.")

self.prior_dist = self.load_prior(prior, prior_kwargs)
self.n_dims = self.get_theta_true().shape[1]
self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False)
Expand Down
3 changes: 1 addition & 2 deletions src/deepdiagnostics/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ def void2(*args, **kwargs):
return None
return void2


Metrics = {
"": void,
"": void,
CoverageFraction.__name__: CoverageFraction,
AllSBC.__name__: AllSBC,
"LC2ST": LC2ST
Expand Down
3 changes: 3 additions & 0 deletions src/deepdiagnostics/metrics/local_two_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sklearn.utils import shuffle

from deepdiagnostics.metrics.metric import Metric
from deepdiagnostics.utils.simulator_utils import SimulatorMissingError

class LocalTwoSampleTest(Metric):
"""
Expand Down Expand Up @@ -46,6 +47,8 @@ def __init__(
percentiles,
number_simulations
)
if not hasattr(self.data, "simulator"):
raise SimulatorMissingError("Missing a simulator to run LC2ST.")

def _collect_data_params(self):
# P is the prior and x_P is generated via the simulator from the parameters P.
Expand Down
2 changes: 1 addition & 1 deletion src/deepdiagnostics/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from deepdiagnostics.plots.parity import Parity
from deepdiagnostics.plots.predictive_prior_check import PriorPC


def void(*args, **kwargs):
def void2(*args, **kwargs):
return None
return void2


Plots = {
"": void,
CDFRanks.__name__: CDFRanks,
Expand Down
3 changes: 3 additions & 0 deletions src/deepdiagnostics/plots/predictive_posterior_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from deepdiagnostics.plots.plot import Display
from deepdiagnostics.utils.plotting_utils import get_hex_colors
from deepdiagnostics.utils.simulator_utils import SimulatorMissingError

class PPC(Display):
"""
Expand Down Expand Up @@ -33,6 +34,8 @@ def __init__(
colorway =None):

super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)
if not hasattr(self.data, "simulator"):
raise SimulatorMissingError("Missing a simulator to run PPC.")

def plot_name(self):
return "predictive_posterior_check.png"
Expand Down
5 changes: 4 additions & 1 deletion src/deepdiagnostics/plots/predictive_prior_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from deepdiagnostics.plots.plot import Display
from deepdiagnostics.utils.simulator_utils import SimulatorMissingError

class PriorPC(Display):
"""
Expand Down Expand Up @@ -36,7 +37,9 @@ def __init__(
colorway = None):

super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)

if not hasattr(self.data, "simulator"):
raise SimulatorMissingError("Missing a simulator to run PriorPC.")

if self.data.simulator_dimensions == 1:
self.plot_image = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,8 @@ def load_simulator(name, simulator_kwargs):
"Simulator improperly formed - requires a simulate method."
)

return simulator_instance
return simulator_instance


class SimulatorMissingError(Exception):
pass
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from deepdiagnostics.data.simulator import Simulator
from deepdiagnostics.models import SBIModel
from deepdiagnostics.utils.config import get_item
from deepdiagnostics.utils.register import register_simulator
from deepdiagnostics.utils.simulator_utils import register_simulator


class MockSimulator(Simulator):
Expand Down
22 changes: 22 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,25 @@ def test_main_missing_args(model_path):
process = subprocess.run(command)
exit_code = process.returncode
assert exit_code == 1


def test_missing_simulator(model_path, data_path):
command = [
"diagnose",
"--model_path",
model_path,
"--data_path",
data_path,
"--simulator",
"Not_A_Registered_Name",
"--plots",
"PPC",
"--metrics",
""
]
process = subprocess.run(command, capture_output=True)
exit_code = process.returncode
stdout = process.stdout.decode("utf-8")
assert exit_code == 0
plot_name = "PPC"
assert f"Cannot run {plot_name} - simulator missing." in stdout

0 comments on commit f489e74

Please sign in to comment.