diff --git a/docs/source/API/client.rst b/docs/source/API/client.rst deleted file mode 100644 index 6124f9e..0000000 --- a/docs/source/API/client.rst +++ /dev/null @@ -1,5 +0,0 @@ -Client -======== - -.. automodule:: client - :members: \ No newline at end of file diff --git a/docs/source/API/metrics.rst b/docs/source/API/metrics.rst deleted file mode 100644 index e32e273..0000000 --- a/docs/source/API/metrics.rst +++ /dev/null @@ -1,5 +0,0 @@ -Metrics -=========== - -.. autoclass:: metrics.metric.Metric - :members: \ No newline at end of file diff --git a/docs/source/API/plots.rst b/docs/source/API/plots.rst deleted file mode 100644 index b47991c..0000000 --- a/docs/source/API/plots.rst +++ /dev/null @@ -1,6 +0,0 @@ -Plots -======== - -.. autoclass:: plots.Plots - :members: - diff --git a/docs/source/API/utils.rst b/docs/source/API/utils.rst deleted file mode 100644 index cc92a44..0000000 --- a/docs/source/API/utils.rst +++ /dev/null @@ -1,6 +0,0 @@ -Utils -======= - -.. autoclass:: utils.config.Config - :members: - diff --git a/docs/source/client.rst b/docs/source/client.rst new file mode 100644 index 0000000..a7cbf56 --- /dev/null +++ b/docs/source/client.rst @@ -0,0 +1,35 @@ +Client +======== + +.. note:: + When running the client, you can supply **either** the configuration yaml file, or the CLI arguments. + You do not need to supply both. + +Use the command `diagnose -h` to view all usage of the CLI helper at any time. +Specific argument descriptions and explanations can be found on the :ref:`configuration` page. + +.. code-block:: bash + + usage: diagnose [-h] [--config CONFIG] [--model_path MODEL_PATH] [--model_engine {SBIModel}] [--data_path DATA_PATH] [--data_engine {H5Data,PickleData}] + [--simulator SIMULATOR] [--out_dir OUT_DIR] [--metrics [{CoverageFraction,AllSBC,LC2ST}]] + [--plots [{CDFRanks,CoverageFraction,Ranks,TARP,LC2ST,PPC}]] + + options: + -h, --help show this help message and exit + --config CONFIG, -c CONFIG + .yaml file with all arguments to run. + --model_path MODEL_PATH, -m MODEL_PATH + String path to a model. Must be compatible with your model_engine choice. + --model_engine {SBIModel}, -e {SBIModel} + Way to load your model. See each module's documentation page for requirements and specifications. + --data_path DATA_PATH, -d DATA_PATH + String path to data. Must be compatible with data_engine choice. + --data_engine {H5Data,PickleData}, -g {H5Data,PickleData} + Way to load your data. See each module's documentation page for requirements and specifications. + --simulator SIMULATOR, -s SIMULATOR + String name of the simulator to use with generative metrics and plots. Must be pre-register with the `utils.register_simulator` method. + --out_dir OUT_DIR Where the results will be saved. Path need not exist, it will be created. + --metrics [{CoverageFraction,AllSBC,LC2ST}] + List of metrics to run. To not run any, supply `--metrics ` + --plots [{CDFRanks,CoverageFraction,Ranks,TARP,LC2ST,PPC}] + List of plots to run. To not run any, supply `--plots ` \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 4f0cba7..55adeea 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,6 +1,6 @@ import sys -sys.path.append("../src") +sys.path.append("../src/deepdiagnostics") # Configuration file for the Sphinx documentation builder. 
# @@ -10,7 +10,7 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = "DeepDiagnostics" +project = "deepdiagnostics" copyright = "2024, Becky Nevin, M Voetberg, Brian Nord" author = "Becky Nevin, M Voetberg, Brian Nord" release = "0.1.0" @@ -23,17 +23,19 @@ "sphinx.ext.autosummary", "sphinx.ext.napoleon", "sphinx_autodoc_typehints", + 'sphinxcontrib.bibtex' ] +bibtex_bibfiles = ['ref.bib'] napoleon_use_param = True autodoc_default_options = { "members": True, } autodoc_typehints = "description" - +autoclass_content = "class" templates_path = ["_templates"] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = "alabaster" +html_theme = "pyramid" html_static_path = ["_static"] diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index 84e9c63..cb9bae0 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -1,2 +1,167 @@ +.. _configuration: + Configuration -=============== \ No newline at end of file +=============== + +The configuration file is a `.yaml` file that makes up the majority of the settings. +It is specified by the user, and if a field is not set, it falls back to a set of pre-defined defaults. +It is split into sections to easily organize different parameters. + +Specifying the Configuration +----------------------------- + +* Pipeline Mode + +To run diagnostics via the command line, pass the path to a yaml file to the diagnostic command. +This will run the entire set of diagnostics according to the configuration file. + +.. code-block:: bash + + diagnose --config path/to/your/config.yaml + +* Standalone Mode +The configuration file is not strictly required for running in standalone mode, +but it can be specified to quickly access variables to avoid re-writing initialization parameters or ensure repeatability. + +.. code-block:: python + + from deepdiagnostics.utils.configuration import Config + + + Config("path/to/your/config.yaml") + + +Configuration Description +----------------------- + +.. attribute:: common + + :param out_dir: Folder where the results of program are saved. The path need not exist, it will be created if it does not. + + :param temp_config: Path to a yaml to store a temporary config. Used only if some arguments are specified outside the config (eg, if using both the --config and --model_path arguments) + + :param sim_location: Path to store settings for simulations. When using the register_simulator method, this is where the registered simulations are catalogued. + + :param random_seed: Integer random seed to use. + +.. code-block:: yaml + + common: + out_dir: "./DeepDiagnosticsResources/results/" + temp_config: "./DeepDiagnosticsResources/temp/temp_config.yml" + sim_location: "DeepDiagnosticsResources/simulators" + random_seed: 42 + +.. attribute:: model + + :param model_path: Path to stored model. Required. + + :param model_engine: Loading method to use. Choose from methods listed in :ref:`models`. + +.. code-block:: yaml + + model: + model_path: {No Default} + model_engine: "SBIModel" + +.. attribute:: data + + :param data_path: Path to stored data. Required. + + :param data_engine: Loading method to use. Choose from methods listed in :ref:`data`. + + :param simulator: String name of the simulator. Must be pre-registered . 
+ + :param prior: Prior distribution used in training. Used if "prior" is not included in the passed data. + + :param prior_kwargs: kwargs to use with the initialization of the prior + + :param simulator_kwargs: kwargs to use with the initialization of the simulation + + :param simulator_dimensions: If the output of the simulation is 1D (non-image) or 2D (images.) + +.. code-block:: yaml + + data: + data_path: {No Default} + data_engine: "H5Data" + prior: "normal" + prior_kwargs: {No Default} + simulator_kwargs: {No Default} + simulator_dimensions: 1 + +.. attribute:: plots_common + + :param axis_spines: Show axis ticks + + :param tight_layout: Minimize the space between axes and labels + + :param default_colorway: String colorway to use. Choose from `matplotlib's named colorways `_. + + :param plot_style: Style sheet. Choose form `matplotlib's style sheets `_. + + :param parameter_labels: Name of each theta parameter to use for titling and labels. Corresponding with the dim=1 axis of theta given by data. + + :param parameter_colors: Colors to use for each theta parameters when representing the parameters on the same plot. + + :param line_style_cycle: Line styles that can be used (besides for solid lines, which are always used.) + + :param figure_size: Default size for square figures. Will be adapted (slightly expanded) for multi-plot figures. + +.. code-block:: yaml + + plots_common: + axis_spines: False + tight_layout: True + default_colorway: "viridis" + plot_style: "fast" + parameter_labels: ["$m$", "$b$"] + parameter_colors: ["#9C92A3", "#0F5257"] + line_style_cycle: ["-", "-."] + figure_size: [6, 6] + +.. attribute:: metrics_common + + These parameters are used for every metric calculated, and for plots that require new inference to be run. + + :param use_progress_bar: Show a progress bar when iteratively performing inference. + + :param samples_per_inference: Number of samples used in a single iteration of inference. + + :param percentiles: List of integer percentiles, for defining coverage regions. + + :param number_simulations: Number of different simulations to run. Often, this means that the number of inferences performed for a metric is samples_per_inference*number_simulations + +.. code-block:: yaml + + metrics_common: + use_progress_bar: False + samples_per_inference: 1000 + percentiles: [75, 85, 95] + number_simulations: 50 + + +.. attribute:: plots + + A dictionary of different plots to generate and their arguments. + Can be any of the implemented plots listed in :ref:`plots` + If the plots are specified with an empty dictionary, defaults from the class are used. + Defaults: ["CDFRanks", "Ranks", "CoverageFraction", "TARP", "LC2ST", "PPC"] + +.. code-block:: yaml + + plots: + TARP: {} + + +.. attribute:: metrics + + A dictionary of different metrics to generate and their arguments. + Can be any of the implemented plots listed in :ref:`metrics` + If the metrics are specified with an empty dictionary, defaults from the class are used. + Defaults: ["AllSBC", "CoverageFraction", "LC2ST"] + +.. code-block:: yaml + + metrics: + LC2ST: {} diff --git a/docs/source/API/data.rst b/docs/source/data.rst similarity index 62% rename from docs/source/API/data.rst rename to docs/source/data.rst index 9e66449..c13d14e 100644 --- a/docs/source/API/data.rst +++ b/docs/source/data.rst @@ -1,3 +1,5 @@ +.. _data: + Data ====== @@ -8,4 +10,7 @@ Data :members: .. autoclass:: data.PickleData - :members: \ No newline at end of file + :members: + +.. 
autoclass:: data.simulator.Simulator + :members: \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 7b2cbc9..138025d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,12 +14,9 @@ Welcome to DeepDiagnostics's documentation! configuration plots metrics - API/client - API/utils - API/data - API/models - API/plots - API/metrics + client + data + models Indices and tables ================== diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 3335acc..12f4241 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -1,4 +1,20 @@ .. _metrics: Metrics -========= \ No newline at end of file +========= + +.. autoclass:: metrics.metric.Metric + :members: + +.. autoclass:: metrics.AllSBC + :members: calculate + +.. autoclass:: metrics.LC2ST + +.. autoclass:: metrics.local_two_sample.LocalTwoSampleTest + :members: calculate + +.. autoclass:: metrics.CoverageFraction + :members: calculate + +.. bibliography:: \ No newline at end of file diff --git a/docs/source/API/models.rst b/docs/source/models.rst similarity index 89% rename from docs/source/API/models.rst rename to docs/source/models.rst index b732091..e4a124d 100644 --- a/docs/source/API/models.rst +++ b/docs/source/models.rst @@ -1,3 +1,5 @@ +.. _models: + Models ======== diff --git a/docs/source/plots.rst b/docs/source/plots.rst index 6b60a3b..d0dd599 100644 --- a/docs/source/plots.rst +++ b/docs/source/plots.rst @@ -1,4 +1,33 @@ .. _plots: Plots -======= \ No newline at end of file +======= + +.. autoclass:: plots.plot.Display + +.. autoclass:: plots.CDFRanks + :members: plot + +.. autoclass:: plots.Ranks + :members: plot + +.. autoclass:: plots.CoverageFraction + :members: plot + +.. autoclass:: plots.TARP + :members: plot + +.. autoclass:: plots.LC2ST +.. autoclass:: plots.local_two_sample.LocalTwoSampleTest + :members: plot + +.. autoclass:: plots.PPC + :members: plot + +.. autoclass:: plots.PriorPC + :members: plot + +.. autoclass:: plots.Parity + :members: plot + +.. bibliography:: \ No newline at end of file diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 8bd263a..31f98c9 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -1,35 +1,49 @@ Quickstart ============ +Notebook Example +----------------- + +`An example notebook can be found here for an interactive walkthrough `_. + Installation -------------- * From PyPi -.. code-block:: console - pip install DeepDiagnostics +.. code-block:: bash + + pip install deepdiagnostics * From Source -.. code-block:: console +.. code-block:: bash + git clone https://github.com/deepskies/DeepDiagnostics/ pip install poetry poetry shell poetry install +Configuration +---- + +Description of the configuration file, including defaults, can be found in :ref:`configuration`. + Pipeline --------- `DeepDiagnostics` includes a CLI tool for analysis. * To run the tool using a configuration file: -.. code-block:: console +.. code-block:: bash + diagnose --config {path to yaml} * To use defaults with specific models and data: -.. code-block:: console +.. code-block:: bash + diagnose --model_path {model pkl} --data_path {data pkl} [--simulator {sim name}] @@ -44,17 +58,62 @@ This requires setting a configuration file ahead of time, and then running the p All plots and metrics can be found in :ref:`plots` and :ref:`metrics`. - .. 
code-block:: python

-    from DeepDiagnostics.utils.configuration import Config
-    from DeepDiagnostics.model import SBIModel
-    from DeepDiagnostics.data import H5Data
+    from deepdiagnostics.utils.configuration import Config
+    from deepdiagnostics.model import SBIModel
+    from deepdiagnostics.data import H5Data

-    from DeepDiagnostics.plots import ...
+    from deepdiagnostics.plots import LC2ST, Ranks

     Config({configuration_path})
     model = SBIModel({model_path})
-    data = H5Data({data_path})
+    data = H5Data({data_path}, simulator={simulator name})
+
+    LC2ST(data=data, model=model, save=False, show=True)(use_intensity_plot=False, n_alpha_samples=200)
+    Ranks(data=data, model=model, save=False, show=True)(num_bins=3)
+
+
+Custom Simulations
+-------------------
+
+To use generative model diagnostics, a simulator has to be included.
+This is done by `registering` your simulation with a name and an associated class.
+
+By doing this, DeepDiagnostics can find your simulation later, and the simulation does not need to be loaded in memory when running the CLI pipeline or the standalone modules.
+
+.. code-block:: python
+
+    from deepdiagnostics.utils.register import register_simulator
+
+    class MySimulation:
+        def __init__(self, ...):
+            ...
+
+
+    register_simulator(simulator_name="MySimulation", simulator=MySimulation)
+
+
+Simulations also require two methods: `generate_context` (used to either load or generate the non-theta input of the simulation, also called `x`) and `simulate`.
+This is enforced by using the abstract class `deepdiagnostics.data.Simulator` as a parent class.
+
+.. code-block:: python
+
+    from deepdiagnostics.data import Simulator
+
+    import numpy as np
+
+
+    class MySimulation(Simulator):
+        def generate_context(self, n_samples: int) -> np.ndarray:
+            """Give a number of samples (int) and get a numpy array of random samples to be used for the simulation"""
+            return np.random.uniform(0, 1, size=n_samples)
+
+        def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray:
+            """Give the parameters of the simulation (theta), and x values (context_samples) and get a simulation sample.
+            Theta and context should have the same shape for dimension 0, the number of samples."""
+            simulation_results = np.zeros((theta.shape[0], 1))
+            for index, context in enumerate(context_samples):
+                simulation_results[index] = theta[index][0]*context + theta[index][1]*context
-    {Plot of choice}
+            return simulation_results
diff --git a/docs/source/ref.bib b/docs/source/ref.bib
new file mode 100644
index 0000000..de75b7c
--- /dev/null
+++ b/docs/source/ref.bib
@@ -0,0 +1,30 @@
+@misc{lemos2023samplingbased,
+      title={Sampling-Based Accuracy Testing of Posterior Estimators for General Inference},
+      author={Pablo Lemos and Adam Coogan and Yashar Hezaveh and Laurence Perreault-Levasseur},
+      year={2023},
+      eprint={2302.03026},
+      archivePrefix={arXiv},
+      primaryClass={stat.ML}
+}
+
+@misc{linhart2023lc2st,
+      title={L-C2ST: Local Diagnostics for Posterior Approximations in Simulation-Based Inference},
+      author={Julia Linhart and Alexandre Gramfort and Pedro L. C.
Rodrigues}, + year={2023}, + eprint={2306.03580}, + archivePrefix={arXiv}, + primaryClass={id='stat.ML' full_name='Machine Learning' is_active=True alt_name=None in_archive='stat' is_general=False description='Covers machine learning papers (supervised, unsupervised, semi-supervised learning, graphical models, reinforcement learning, bandits, high dimensional inference, etc.) with a statistical or theoretical grounding'} +} + +@article{centero2020sbi, + doi = {10.21105/joss.02505}, + url = {https://doi.org/10.21105/joss.02505}, + year = {2020}, + publisher = {The Open Journal}, + volume = {5}, + number = {52}, + pages = {2505}, + author = {Alvaro Tejero-Cantero and Jan Boelts and Michael Deistler and Jan-Matthis Lueckmann and Conor Durkan and Pedro J. Gonçalves and David S. Greenberg and Jakob H. Macke}, + title = {sbi: A toolkit for simulation-based inference}, + journal = {Journal of Open Source Software} +} \ No newline at end of file diff --git a/src/data/__init__.py b/src/data/__init__.py deleted file mode 100644 index b61257c..0000000 --- a/src/data/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from data.h5_data import H5Data -from data.pickle_data import PickleData - -DataModules = {"H5Data": H5Data, "PickleData": PickleData} diff --git a/src/data/data.py b/src/data/data.py deleted file mode 100644 index 0dcdef1..0000000 --- a/src/data/data.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Optional -import numpy as np - -from utils.config import get_item -from utils.register import load_simulator - -class Data: - def __init__( - self, - path: str, - simulator_name: str, - simulator_kwargs: dict = None, - prior: str = None, - prior_kwargs: dict = None, - simulation_dimensions:Optional[int] = None, - ): - self.rng = np.random.default_rng( - get_item("common", "random_seed", raise_exception=False) - ) - self.data = self._load(path) - self.simulator = load_simulator(simulator_name, simulator_kwargs) - self.prior_dist = self.load_prior(prior, prior_kwargs) - self.n_dims = self.get_theta_true().shape[1] - self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False) - - def get_simulator_output_shape(self): - context_shape = self.true_context().shape - sim_out = self.simulator(theta=self.get_theta_true()[0:1, :], n_samples=context_shape[-1]) - return sim_out.shape - - def _load(self, path: str): - raise NotImplementedError - - def true_context(self): - # From Data - raise NotImplementedError - - def true_simulator_outcome(self): - return self.simulator(self.get_theta_true(), self.true_context()) - - def sample_prior(self, n_samples: int): - return self.prior_dist(size=(n_samples, self.n_dims)) - - def simulator_outcome(self, theta, condition_context=None, n_samples=None): - if condition_context is None: - if n_samples is None: - raise ValueError( - "Samples required if condition context is not specified" - ) - return self.simulator(theta, n_samples) - else: - return self.simulator.simulate(theta, condition_context) - - def simulated_context(self, n_samples): - return self.simulator.generate_context(n_samples) - - def get_theta_true(self): - if hasattr(self, "theta_true"): - return self.theta_true - else: - return get_item("data", "theta_true", raise_exception=True) - - def get_sigma_true(self): - if hasattr(self, "sigma_true"): - return self.sigma_true() - else: - return get_item("data", "sigma_true", raise_exception=True) - - def save(self, data, path: str): - raise NotImplementedError - - def 
read_prior(self): - raise NotImplementedError - - def load_prior(self, prior, prior_kwargs): - if prior is None: - prior = get_item("data", "prior", raise_exception=False) - try: - prior = self.read_prior() - except NotImplementedError: - choices = { - "normal": self.rng.normal, - "poisson": self.rng.poisson, - "uniform": self.rng.uniform, - "gamma": self.rng.gamma, - "beta": self.rng.beta, - "binominal": self.rng.binomial, - } - - if prior not in choices.keys(): - raise NotImplementedError( - f"{prior} is not an option for a prior, choose from {list(choices.keys())}" - ) - if prior_kwargs is None: - prior_kwargs = {} - return lambda size: choices[prior](**prior_kwargs, size=size) - - except KeyError as e: - raise RuntimeError(f"Data missing a prior specification - {e}") diff --git a/src/data/h5_data.py b/src/data/h5_data.py deleted file mode 100644 index 8d6dd17..0000000 --- a/src/data/h5_data.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Any, Callable, Optional -import h5py -import numpy as np -import torch -import os - -from data.data import Data - - -class H5Data(Data): - def __init__(self, - path: str, - simulator: Callable, - simulator_kwargs: dict = None, - prior: str = None, - prior_kwargs: dict = None, - simulation_dimensions:Optional[int] = None, - ): - super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions) - - def _load(self, path): - assert path.split(".")[-1] == "h5", "File extension must be h5" - loaded_data = {} - with h5py.File(path, "r") as file: - for key in file.keys(): - loaded_data[key] = torch.Tensor(file[key][...]) - return loaded_data - - def save(self, data: dict[str, Any], path: str): # Todo typing for data dict - assert path.split(".")[-1] == "h5", "File extension must be h5" - if not os.path.exists(os.path.dirname(path)): - os.makedirs(os.path.dirname(path)) - - data_arrays = {key: np.asarray(value) for key, value in data.items()} - with h5py.File(path, "w") as file: - # Save each array as a dataset in the HDF5 file - for key, value in data_arrays.items(): - file.create_dataset(key, data=value) - - def true_context(self): - # From Data - return self.data["xs"] # TODO change name - - def prior(self): - # From Data - raise NotImplementedError - - def get_theta_true(self): - return self.data["thetas"] - - def get_sigma_true(self): - try: - return super().get_sigma_true() - except (AssertionError, KeyError): - return 1 diff --git a/src/data/pickle_data.py b/src/data/pickle_data.py deleted file mode 100644 index ea2fa9d..0000000 --- a/src/data/pickle_data.py +++ /dev/null @@ -1,26 +0,0 @@ -import pickle -from typing import Any, Callable, Optional -from data.data import Data - - -class PickleData(Data): - def __init__(self, - path: str, - simulator: Callable, - simulator_kwargs: dict = None, - prior: str = None, - prior_kwargs: dict = None, - simulation_dimensions:Optional[int] = None, - ): - super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions) - - def _load(self, path: str): - assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" - with open(path, "rb") as file: - data = pickle.load(file) - return data - - def save(self, data: Any, path: str): - assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" - with open(path, "wb") as file: - pickle.dump(data, file) diff --git a/src/deepdiagnostics/__init__.py b/src/deepdiagnostics/__init__.py new file mode 100644 index 0000000..683091b --- /dev/null +++ b/src/deepdiagnostics/__init__.py @@ -0,0 +1 @@ 
+version = "0.1.0" \ No newline at end of file diff --git a/src/__init__.py b/src/deepdiagnostics/client/__init__.py similarity index 100% rename from src/__init__.py rename to src/deepdiagnostics/client/__init__.py diff --git a/src/client/client.py b/src/deepdiagnostics/client/client.py similarity index 65% rename from src/client/client.py rename to src/deepdiagnostics/client/client.py index 147bdf4..59c8024 100644 --- a/src/client/client.py +++ b/src/deepdiagnostics/client/client.py @@ -2,38 +2,47 @@ import yaml from argparse import ArgumentParser -from utils.config import Config -from utils.defaults import Defaults -from data import DataModules -from models import ModelModules -from metrics import Metrics -from plots import Plots +from deepdiagnostics.utils.config import Config +from deepdiagnostics.utils.defaults import Defaults +from deepdiagnostics.data import DataModules +from deepdiagnostics.models import ModelModules +from deepdiagnostics.metrics import Metrics +from deepdiagnostics.plots import Plots def parser(): parser = ArgumentParser() - parser.add_argument("--config", "-c", default=None) + parser.add_argument("--config", "-c", default=None, help=".yaml file with all arguments to run.") # Model - parser.add_argument("--model_path", "-m", default=None) + parser.add_argument("--model_path", "-m", default=None, help="String path to a model. Must be compatible with your model_engine choice.") parser.add_argument( "--model_engine", "-e", default=Defaults["model"]["model_engine"], choices=ModelModules.keys(), + help="Way to load your model. See each module's documentation page for requirements and specifications." ) # Data - parser.add_argument("--data_path", "-d", default=None) + parser.add_argument("--data_path", "-d", default=None, help="String path to data. Must be compatible with data_engine choice.") parser.add_argument( "--data_engine", "-g", default=Defaults["data"]["data_engine"], choices=DataModules.keys(), + help="Way to load your data. See each module's documentation page for requirements and specifications." ) - parser.add_argument("--simulator", "-s", default=None) + parser.add_argument( + "--simulator", "-s", + default=None, + help='String name of the simulator to use with generative metrics and plots. Must be pre-register with the `utils.register_simulator` method.') # Common - parser.add_argument("--out_dir", default=Defaults["common"]["out_dir"]) + parser.add_argument( + "--out_dir", + default=Defaults["common"]["out_dir"], + help="Where the results will be saved. Path need not exist, it will be created." + ) # List of metrics (cannot supply specific kwargs) parser.add_argument( @@ -41,6 +50,7 @@ def parser(): nargs="?", default=list(Defaults["metrics"].keys()), choices=Metrics.keys(), + help="List of metrics to run. To not run any, supply `--metrics `" ) # List of plots @@ -49,6 +59,8 @@ def parser(): nargs="?", default=list(Defaults["plots"].keys()), choices=Plots.keys(), + help="List of plots to run. 
To not run any, supply `--plots `" + ) args = parser.parse_args() diff --git a/src/deepdiagnostics/data/__init__.py b/src/deepdiagnostics/data/__init__.py new file mode 100644 index 0000000..43a04a9 --- /dev/null +++ b/src/deepdiagnostics/data/__init__.py @@ -0,0 +1,4 @@ +from deepdiagnostics.data.h5_data import H5Data +from deepdiagnostics.data.pickle_data import PickleData + +DataModules = {"H5Data": H5Data, "PickleData": PickleData} diff --git a/src/deepdiagnostics/data/data.py b/src/deepdiagnostics/data/data.py new file mode 100644 index 0000000..3b9539f --- /dev/null +++ b/src/deepdiagnostics/data/data.py @@ -0,0 +1,193 @@ +from typing import Any, Optional, Sequence, Union +import numpy as np + +from deepdiagnostics.utils.config import get_item +from deepdiagnostics.utils.register import load_simulator + +class Data: + """ + Load stored data to use in diagnostics + + Args: + path (str): path to the data file. + simulator_name (str): Name of the register simulator. If your simulator is not registered with utils.register_simulator, it will produce an error here. + simulator_kwargs (dict, optional): Any additional kwargs used set up your simulator. Defaults to None. + prior (str, optional): If the prior is not given in the data, use a numpy random distribution. Specified by name. Choose from: { + "normal" + "poisson" + "uniform" + "gamma" + "beta" + "binominal}. Defaults to None. + prior_kwargs (dict, optional): kwargs for the numpy prior. `View this page for a description `_. Defaults to None. + simulation_dimensions (Optional[int], optional): 1 or 2. 1->output of the simulator has one dimensions, 2->output has two dimensions (is an image). Defaults to None. + """ + def __init__( + self, + path: str, + simulator_name: str, + simulator_kwargs: dict = None, + prior: str = None, + prior_kwargs: dict = None, + simulation_dimensions:Optional[int] = None, + ): + self.rng = np.random.default_rng( + get_item("common", "random_seed", raise_exception=False) + ) + self.data = self._load(path) + self.simulator = load_simulator(simulator_name, simulator_kwargs) + self.prior_dist = self.load_prior(prior, prior_kwargs) + self.n_dims = self.get_theta_true().shape[1] + self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False) + + def get_simulator_output_shape(self) -> tuple[Sequence[int]]: + """ + Run a single sample of the simulator to verify the out-shape. + + Returns: + tuple[Sequence[int]]: Output shape of a single sample of the simulator. + """ + context_shape = self.true_context().shape + sim_out = self.simulator(theta=self.get_theta_true()[0:1, :], n_samples=context_shape[-1]) + return sim_out.shape + + def _load(self, path: str): + raise NotImplementedError + + def true_context(self): + """ + True data x values, if supplied by the data method. + """ + # From Data + raise NotImplementedError + + def true_simulator_outcome(self) -> np.ndarray: + """ + Run the simulator on all true theta and true x values. + + Returns: + np.ndarray: array of (n samples, simulator shape) showing output of the simulator on all true samples in data. 
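+
+        Example (illustrative only; ``data.h5`` and ``MySimulation`` are placeholder names for your own file and registered simulator):
+
+        .. code-block:: python
+
+            from deepdiagnostics.data import H5Data
+
+            data = H5Data("data.h5", simulator="MySimulation")
+            outcomes = data.true_simulator_outcome()
+            # One simulator output per stored (theta, x) pair
+            assert outcomes.shape[0] == data.get_theta_true().shape[0]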
+ """ + return self.simulator(self.get_theta_true(), self.true_context()) + + def sample_prior(self, n_samples: int) -> np.ndarray: + """ + Draw samples from the simulator + + Args: + n_samples (int): Number of samples to draw + + Returns: + np.ndarray: + """ + return self.prior_dist(size=(n_samples, self.n_dims)) + + def simulator_outcome(self, theta:np.ndarray, condition_context:np.ndarray=None, n_samples:int=None): + """_summary_ + + Args: + theta (np.ndarray): Theta value of shape (n_samples, theta_dimensions) + condition_context (np.ndarray, optional): If x values for theta are known, use them. Defaults to None. + n_samples (int, optional): If x values are not known for theta, draw them randomly. Defaults to None. + + Raises: + ValueError: If either n samples or content samples is supplied. + + Returns: + np.ndarray: Simulator output of shape (n samples, simulator_dimensions) + """ + if condition_context is None: + if n_samples is None: + raise ValueError( + "Samples required if condition context is not specified" + ) + return self.simulator(theta, n_samples) + else: + return self.simulator.simulate(theta, condition_context) + + def simulated_context(self, n_samples:int) -> np.ndarray: + """ + Call the simulator's `generate_context` method. + + Args: + n_samples (int): Number of samples to draw. + + Returns: + np.ndarray: context (x values), as defined by the simulator. + """ + return self.simulator.generate_context(n_samples) + + def get_theta_true(self) -> Union[Any, float, int, np.ndarray]: + """ + Look for the true theta given by data. If supplied in the method, use that, other look in the configuration file. + If neither are supplied, return None. + + Returns: + Any: Theta value selected by the search. + """ + if hasattr(self, "theta_true"): + return self.theta_true + else: + return get_item("data", "theta_true", raise_exception=True) + + def get_sigma_true(self) -> Union[Any, float, int, np.ndarray]: + """ + Look for the true sigma of data. If supplied in the method, use that, other look in the configuration file. + If neither are supplied, return 1. + + Returns: + Any: Sigma value selected by the search. + """ + if hasattr(self, "sigma_true"): + return self.sigma_true() + else: + return get_item("data", "sigma_true", raise_exception=True) + + def save(self, data, path: str): + raise NotImplementedError + + def read_prior(self): + raise NotImplementedError + + def load_prior(self, prior:str, prior_kwargs:dict[str, any]) -> callable: + """ + Load the prior. + Either try to get it from data (if it has been implemented for the type of data), + or use numpy to initialize a random distribution using the prior argument. + + Args: + prior (str): Name of prior. + prior_kwargs (dict[str, any]): kwargs for initializing the prior. + + Raises: + NotImplementedError: The selected prior is not included. + RuntimeError: The selected prior is missing arguments to initialize. 
+ + Returns: + callable: Prior that can be sampled from by calling it with prior(n_samples) + """ + + if prior is None: + prior = get_item("data", "prior", raise_exception=False) + try: + prior = self.read_prior() + except NotImplementedError: + choices = { + "normal": self.rng.normal, + "poisson": self.rng.poisson, + "uniform": self.rng.uniform, + "gamma": self.rng.gamma, + "beta": self.rng.beta, + "binominal": self.rng.binomial, + } + + if prior not in choices.keys(): + raise NotImplementedError( + f"{prior} is not an option for a prior, choose from {list(choices.keys())}" + ) + if prior_kwargs is None: + prior_kwargs = {} + return lambda size: choices[prior](**prior_kwargs, size=size) + + except KeyError as e: + raise RuntimeError(f"Data missing a prior specification - {e}") diff --git a/src/deepdiagnostics/data/h5_data.py b/src/deepdiagnostics/data/h5_data.py new file mode 100644 index 0000000..c5e50ff --- /dev/null +++ b/src/deepdiagnostics/data/h5_data.py @@ -0,0 +1,100 @@ +from typing import Any +import h5py +import numpy as np +import torch +import os + +from deepdiagnostics.data.data import Data + + +class H5Data(Data): + """ + Load data that has been saved in a h5 format. + + .. attribute:: Data Parameters + + :xs: [REQUIRED] The context, the x values. The data that was used to train a model on what conditions produce what posterior. + :thetas: [REQUIRED] The theta, the parameters of the external model. The data used to train the model's posterior. + :prior: Distribution used to initialize the posterior before training. + :sigma: True standard deviation of the actual thetas, if known. + + """ + + def __init__(self, + path, + simulator, + simulator_kwargs = None, + prior=None, + prior_kwargs = None, + simulation_dimensions = None, + ): + super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions) + + def _load(self, path): + assert path.split(".")[-1] == "h5", "File extension must be h5" + loaded_data = {} + with h5py.File(path, "r") as file: + for key in file.keys(): + loaded_data[key] = torch.Tensor(file[key][...]) + return loaded_data + + def save(self, data: dict[str, Any], path: str): # Todo typing for data dict + assert path.split(".")[-1] == "h5", "File extension must be h5" + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + + data_arrays = {key: np.asarray(value) for key, value in data.items()} + with h5py.File(path, "w") as file: + # Save each array as a dataset in the HDF5 file + for key, value in data_arrays.items(): + file.create_dataset(key, data=value) + + def true_context(self): + """ + Try to get the `xs` field of the loaded data. + + Raises: + NotImplementedError: The data does not have a `xs` field. + """ + try: + return self.data["xs"] + except KeyError: + raise NotImplementedError("Cannot find `xs` in data. Please supply it.") + + def prior(self): + """ + If the data has a supplied prior, return it. If not, the data module will default back to picking a prior from a random distribution. + + Raises: + NotImplementedError: The data does not have a `prior` field. + """ + try: + return self.data['prior'] + except KeyError: + raise NotImplementedError("Data does not have a `prior` field.") + + def get_theta_true(self): + """ Get stored theta used to train the model. + + Returns: + theta array + + Raise: + NotImplementedError: Data does not have thetas. 
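+
+        The ``thetas`` dataset is read directly from the file. A minimal compatible file could be written
+        with ``h5py`` (a sketch; the array shapes are only illustrative):
+
+        .. code-block:: python
+
+            import h5py
+            import numpy as np
+
+            with h5py.File("train_data.h5", "w") as f:
+                f.create_dataset("thetas", data=np.zeros((100, 2)))  # (n samples, n parameters)
+                f.create_dataset("xs", data=np.zeros((100, 10)))     # (n samples, context dimension)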
+ """ + try: + return self.data["thetas"] + except KeyError: + raise NotImplementedError("Data does not have a `thetas` field.") + + def get_sigma_true(self): + """ + Try to get the true standard deviation of the data. If it is not supplied, return 1. + + Returns: + Any: sigma. + """ + try: + return super().get_sigma_true() + except (AssertionError, KeyError): + return 1 diff --git a/src/deepdiagnostics/data/pickle_data.py b/src/deepdiagnostics/data/pickle_data.py new file mode 100644 index 0000000..0659c02 --- /dev/null +++ b/src/deepdiagnostics/data/pickle_data.py @@ -0,0 +1,36 @@ +import pickle +from typing import Any +from deepdiagnostics.data.data import Data + + +class PickleData(Data): + """ + Load data that is saved as a .pkl file. + """ + def __init__(self, + path, + simulator, + simulator_kwargs = None, + prior = None, + prior_kwargs = None, + simulation_dimensions = None, + ): + super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions) + + def _load(self, path: str) -> Any: + assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" + with open(path, "rb") as file: + data = pickle.load(file) + return data + + def save(self, data: Any, path: str) -> None: + """ + Save data in the form of a .pkl file. + + Args: + data (Any): Data that can be encoded into a pkl. + path (str): Out file path for the data. Must have a .pkl extension. + """ + assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" + with open(path, "wb") as file: + pickle.dump(data, file) diff --git a/src/data/simulator.py b/src/deepdiagnostics/data/simulator.py similarity index 57% rename from src/data/simulator.py rename to src/deepdiagnostics/data/simulator.py index 7274086..6b77d77 100644 --- a/src/data/simulator.py +++ b/src/deepdiagnostics/data/simulator.py @@ -14,9 +14,27 @@ def __init__(self) -> None: def generate_context(self, n_samples: int) -> np.ndarray: """ [ABSTRACT, MUST BE FILLED] + Specify how the conditioning context is generated. Can come from data, or from a generic distribution. + Example: + + .. code-block:: python + + # Generate from a random distribution + class MySim(Simulator): + def generate_context(self, n_samples: int) -> np.ndarray: + return np.random.uniform(0, 1) + + # Draw from a sample + class MySim(Simulator): + def __init__(self): + self.data_source = ..... + + def generate_context(self, n_samples: int) -> np.ndarray: + return self.data_source.sample(n_samples) + Args: n_samples (int): Number of samples of context to pull @@ -29,8 +47,21 @@ def generate_context(self, n_samples: int) -> np.ndarray: def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray: """ [ABSTRACT, MUST BE FILLED] + Specify a simulation S such that y_{theta} = S(context_samples|theta) + Example: + .. 
code-block:: python + + # Generate from a random distribution + class MySim(Simulator): + def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray: + simulation_results = np.zeros(theta.shape[0], 1) + for index, context in enumerate(context_samples): + simulation_results[index] = theta[index][0]*context + theta[index][1]*context + + return simulation_results + Args: theta (np.ndarray): Parameters of the simulation model context_samples (np.ndarray): Samples to use with the theta-primed simulation model diff --git a/src/deepdiagnostics/metrics/__init__.py b/src/deepdiagnostics/metrics/__init__.py new file mode 100644 index 0000000..18dc3df --- /dev/null +++ b/src/deepdiagnostics/metrics/__init__.py @@ -0,0 +1,10 @@ +from deepdiagnostics.metrics.all_sbc import AllSBC +from deepdiagnostics.metrics.coverage_fraction import CoverageFraction +from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST + +Metrics = { + "": lambda **kwargs: None, + CoverageFraction.__name__: CoverageFraction, + AllSBC.__name__: AllSBC, + "LC2ST": LC2ST +} diff --git a/src/metrics/all_sbc.py b/src/deepdiagnostics/metrics/all_sbc.py similarity index 51% rename from src/metrics/all_sbc.py rename to src/deepdiagnostics/metrics/all_sbc.py index e33a230..7f3f6f2 100644 --- a/src/metrics/all_sbc.py +++ b/src/deepdiagnostics/metrics/all_sbc.py @@ -1,22 +1,33 @@ -from typing import Any, Optional, Sequence +from typing import Any, Sequence from torch import tensor from sbi.analysis import run_sbc, check_sbc -from metrics.metric import Metric -from utils.config import get_item +from deepdiagnostics.metrics.metric import Metric class AllSBC(Metric): + """ + Calculate SBC diagnostics metrics and add them to the output. + Adapted from :cite:p:`centero2020sbi`. + More information about specific metrics can be found `here `_. + + .. code-block:: python + + from deepdiagnostics.metrics import AllSBC + + metrics = AllSBC(model, data, save=False)() + metrics = metrics.output + """ def __init__( self, - model: Any, - data: Any, - out_dir: str | None = None, - save: bool=True, - use_progress_bar: Optional[bool] = None, - samples_per_inference: Optional[int] = None, - percentiles: Optional[Sequence[int]] = None, - number_simulations: Optional[int] = None, + model, + data, + out_dir= None, + save = True, + use_progress_bar = None, + samples_per_inference = None, + percentiles = None, + number_simulations = None, ) -> None: super().__init__(model, data, out_dir, @@ -30,7 +41,13 @@ def _collect_data_params(self): self.thetas = tensor(self.data.get_theta_true()) self.context = tensor(self.data.true_context()) - def calculate(self): + def calculate(self) -> dict[str, Sequence]: + """ + Calculate all SBC diagnostic metrics + + Returns: + dict[str, Sequence]: Dictionary with all calculations, labeled by their name. 
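+
+        A sketch of inspecting the result (the exact keys are whatever ``sbi``'s ``check_sbc`` reports):
+
+        .. code-block:: python
+
+            stats = AllSBC(model, data, save=False).calculate()
+            for name, value in stats.items():
+                print(name, value)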
+ """ ranks, dap_samples = run_sbc( self.thetas, self.context, diff --git a/src/metrics/coverage_fraction.py b/src/deepdiagnostics/metrics/coverage_fraction.py similarity index 56% rename from src/metrics/coverage_fraction.py rename to src/deepdiagnostics/metrics/coverage_fraction.py index 24c494a..9109c98 100644 --- a/src/metrics/coverage_fraction.py +++ b/src/deepdiagnostics/metrics/coverage_fraction.py @@ -1,25 +1,30 @@ import numpy as np -from torch import tensor from tqdm import tqdm -from typing import Any, Optional, Sequence - -from metrics.metric import Metric -from utils.config import get_item +from typing import Any, Sequence +from deepdiagnostics.metrics.metric import Metric class CoverageFraction(Metric): - """ """ + """ + Calculate the coverage of a set number of inferences over different confidence regions. + + .. code-block:: python + + from deepdiagnostics.metrics import CoverageFraction + + samples, coverage = CoverageFraction(model, data, save=False).calculate() + """ def __init__( self, - model: Any, - data: Any, - out_dir: Optional[str] = None, - save: bool=True, - use_progress_bar: Optional[bool] = None, - samples_per_inference: Optional[int] = None, - percentiles: Optional[Sequence[int]] = None, - number_simulations: Optional[int] = None, + model, + data, + out_dir= None, + save=True, + use_progress_bar = None, + samples_per_inference = None, + percentiles = None, + number_simulations = None, ) -> None: super().__init__(model, data, out_dir, @@ -36,23 +41,31 @@ def _collect_data_params(self): def _run_model_inference(self, samples_per_inference, y_inference): samples = self.model.sample_posterior(samples_per_inference, y_inference) - return samples + return samples.numpy() + + def calculate(self) -> tuple[Sequence, Sequence]: + """ + Calculate the coverage fraction of the given model and data - def calculate(self): + Returns: + tuple[Sequence, Sequence]: A tuple of the samples tested (M samples, Samples per inference, N parameters) and the coverage over those samples. 
+ """ all_samples = np.empty( - (len(self.context), self.samples_per_inference, np.shape(self.thetas)[1]) + (self.number_simulations, self.samples_per_inference, np.shape(self.thetas)[1]) ) count_array = [] - iterator = enumerate(self.context) + iterator = range(self.number_simulations) if self.use_progress_bar: iterator = tqdm( iterator, desc="Sampling from the posterior for each observation", unit=" observation", ) - for y_sample_index, y_sample in iterator: - samples = self._run_model_inference(self.samples_per_inference, y_sample) - all_samples[y_sample_index] = samples + for sample_index in iterator: + context_sample = self.context[self.data.rng.integers(0, len(self.context))] + samples = self._run_model_inference(self.samples_per_inference, context_sample) + + all_samples[sample_index] = samples count_vector = [] # step through the percentile list @@ -63,12 +76,10 @@ def calculate(self): # find the percentile for the posterior for this observation # this is n_params dimensional # the units are in parameter space - confidence_lower = tensor( - np.percentile(samples.cpu(), percentile_lower, axis=0) - ) - confidence_upper = tensor( - np.percentile(samples.cpu(), percentile_upper, axis=0) - ) + confidence_lower = np.percentile(samples, percentile_lower, axis=0) + + confidence_upper = np.percentile(samples, percentile_upper, axis=0) + # this is asking if the true parameter value # is contained between the @@ -76,10 +87,11 @@ def calculate(self): # checks separately for each side of the 50th percentile count = np.logical_and( - confidence_upper - self.thetas[y_sample_index, :] > 0, - self.thetas[y_sample_index, :] - confidence_lower > 0, + confidence_upper - self.thetas[sample_index, :].numpy() > 0, + self.thetas[sample_index, :].numpy() - confidence_lower > 0, ) count_vector.append(count) + # each time the above is > 0, adds a count count_array.append(count_vector) diff --git a/src/metrics/local_two_sample.py b/src/deepdiagnostics/metrics/local_two_sample.py similarity index 80% rename from src/metrics/local_two_sample.py rename to src/deepdiagnostics/metrics/local_two_sample.py index 8a563fb..c6e7e50 100644 --- a/src/metrics/local_two_sample.py +++ b/src/deepdiagnostics/metrics/local_two_sample.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Sequence, Union +from typing import Any, Union import numpy as np @@ -6,19 +6,34 @@ from sklearn.neural_network import MLPClassifier from sklearn.utils import shuffle -from metrics.metric import Metric +from deepdiagnostics.metrics.metric import Metric class LocalTwoSampleTest(Metric): + """ + Adapted from :cite:p:`linhart2023lc2st`. + Train a classifier to verify the quality of the posterior via classifier accuracy. + Produces an array of inference accuracies for the trained classier, representing the cases of either denying the null hypothesis + (that the posterior output of the simulation is not significantly different from a given random sample.) + + Code referenced from: + `github.com/JuliaLinhart/lc2st/lc2st.py::train_lc2st `_. + + .. 
code-block:: python + + from deepdiagnostics.metrics import LC2ST + + true_probabilities, null_hypothesis_probabilities = LC2ST(model, data, save=False).calculate() + """ def __init__( self, - model: Any, - data: Any, - out_dir: Optional[str] = None, - save: bool=True, - use_progress_bar: Optional[bool] = None, - samples_per_inference: Optional[int] = None, - percentiles: Optional[Sequence[int]] = None, - number_simulations: Optional[int] = None, + model, + data, + out_dir = None, + save=True, + use_progress_bar = None, + samples_per_inference = None, + percentiles = None, + number_simulations = None, ) -> None: super().__init__( @@ -32,7 +47,6 @@ def __init__( number_simulations ) - def _collect_data_params(self): # P is the prior and x_P is generated via the simulator from the parameters P. self.p = self.data.sample_prior(self.number_simulations) @@ -74,7 +88,6 @@ def _collect_data_params(self): self.outcome_given_q[index] = q_outcome.ravel() # Q is the approximate posterior amortized in x - def train_linear_classifier( self, p, q, x_p, x_q, classifier: str, classifier_kwargs: dict = {} ): @@ -188,7 +201,7 @@ def _cross_eval_score( ) return probabilities - def permute_data(self, P, Q): + def _permute_data(self, P, Q): """Permute the concatenated data [P,Q] to create null-hyp samples. Args: @@ -206,7 +219,22 @@ def calculate( cross_evaluate: bool = True, n_null_hypothesis_trials=100, classifier_kwargs: Union[dict, list[dict]] = None, - ): + ) -> tuple[np.ndarray, np.ndarray]: + """ + Perform the calculation for the LC2ST. + Adds the results to the lc2st.output (dict) under the parameters + "lc2st_probabilities", "lc2st_null_hypothesis_probabilities" as lists. + + Args: + linear_classifier (Union[str, list[str]], optional): linear classifier to use for the test. Only MLP is implemented. Defaults to "MLP". + cross_evaluate (bool, optional): Use a k-fold'd dataset for evaluation. Defaults to True. + n_null_hypothesis_trials (int, optional): Number of times to draw and test the null hypothesis. Defaults to 100. + classifier_kwargs (Union[dict, list[dict]], optional): Additional kwargs for the linear classifier. Defaults to None. + + Returns: + tuple[np.ndarray, np.ndarray]: arrays storing the true and null hypothesis probabilities given the linear classifier. + + """ if isinstance(linear_classifier, str): linear_classifier = [linear_classifier] @@ -228,7 +256,7 @@ def calculate( for _ in range(n_null_hypothesis_trials): joint_P_x = np.concatenate([self.p, self.outcome_given_p], axis=1) joint_Q_x = np.concatenate([self.q, self.outcome_given_q], axis=1) - joint_P_x_perm, joint_Q_x_perm = self.permute_data( + joint_P_x_perm, joint_Q_x_perm = self._permute_data( joint_P_x, joint_Q_x, ) diff --git a/src/metrics/metric.py b/src/deepdiagnostics/metrics/metric.py similarity index 64% rename from src/metrics/metric.py rename to src/deepdiagnostics/metrics/metric.py index b92413b..42ba170 100644 --- a/src/metrics/metric.py +++ b/src/deepdiagnostics/metrics/metric.py @@ -2,12 +2,27 @@ import json import os -from data import data -from models import model -from utils.config import get_item +from deepdiagnostics.data import data +from deepdiagnostics.models import model +from deepdiagnostics.utils.config import get_item class Metric: + """ + These parameters are used for every metric calculated, and for plots that require new inference to be run. + Calculate a given metric. Save output to a json if out_dir and saving specified. 
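+
+    All built-in metrics accept these same constructor arguments, so they can be run interchangeably
+    (a sketch; the argument values are illustrative):
+
+    .. code-block:: python
+
+        from deepdiagnostics.metrics import AllSBC, CoverageFraction
+
+        for metric in (AllSBC, CoverageFraction):
+            metric(model, data, out_dir="./results/", save=True, samples_per_inference=500).calculate()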
+ + Args: + model (deepdiagnostics.models.model): Model to calculate the metric for. Required. + data (deepdiagnostics.data.data): Data to test against. Required. + out_dir (Optional[str], optional): Directory to save a json (results.json) to. Defaults to None. + save (bool, optional): Save the output to json. Defaults to True. + use_progress_bar (Optional[bool], optional):Show a progress bar when iteratively performing inference. Defaults to None. + samples_per_inference (Optional[int], optional) :Number of samples used in a single iteration of inference. Defaults to None. + percentiles (Optional[Sequence[int]], optional): List of integer percentiles, for defining coverage regions. Defaults to None. + number_simulations (Optional[int], optional):Number of different simulations to run. Often, this means that the number of inferences performed for a metric is samples_per_inference*number_simulations. Defaults to None. + """ + def __init__( self, model: model, diff --git a/src/deepdiagnostics/models/__init__.py b/src/deepdiagnostics/models/__init__.py new file mode 100644 index 0000000..29ee5f3 --- /dev/null +++ b/src/deepdiagnostics/models/__init__.py @@ -0,0 +1,3 @@ +from deepdiagnostics.models.sbi_model import SBIModel + +ModelModules = {"SBIModel": SBIModel} diff --git a/src/models/model.py b/src/deepdiagnostics/models/model.py similarity index 75% rename from src/models/model.py rename to src/deepdiagnostics/models/model.py index 254d584..46e2c81 100644 --- a/src/models/model.py +++ b/src/deepdiagnostics/models/model.py @@ -1,4 +1,10 @@ class Model: + """ + Load a pre-trained model for analysis. + + Args: + model_path (str): relative path to a model. + """ def __init__(self, model_path: str) -> None: self.model = self._load(model_path) diff --git a/src/deepdiagnostics/models/sbi_model.py b/src/deepdiagnostics/models/sbi_model.py new file mode 100644 index 0000000..f062698 --- /dev/null +++ b/src/deepdiagnostics/models/sbi_model.py @@ -0,0 +1,56 @@ +import os +import pickle + +from deepdiagnostics.models.model import Model + + +class SBIModel(Model): + """ + Load a trained model that was generated with Mackelab SBI :cite:p:`centero2020sbi`. + `Read more about saving and loading requirements here `_. + + Args: + model_path (str): relative path to a model - must be a .pkl file. + """ + def __init__(self, model_path): + super().__init__(model_path) + + def _load(self, path: str) -> None: + assert os.path.exists(path), f"Cannot find model file at location {path}" + assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" + + with open(path, "rb") as file: + posterior = pickle.load(file) + self.posterior = posterior + + def sample_posterior(self, n_samples: int, x_true): + """ + Sample the posterior + + Args: + n_samples (int): Number of samples to draw + x_true (np.ndarray): Context samples. (must be dims=(n_samples, M)) + + Returns: + np.ndarray: Posterior samples + """ + return self.posterior.sample( + (n_samples,), x=x_true, show_progress_bars=False + ).cpu() # TODO Unbind from cpu + + def predict_posterior(self, data, context_samples): + """ + Sample the posterior and then + + Args: + data (deepdiagnostics.data.Data): Data module with the loaded simulation + context_samples (np.ndarray): X values to test the posterior over. 
+ + Returns: + np.ndarray: Simulator output + """ + posterior_samples = self.sample_posterior(context_samples) + posterior_predictive_samples = data.simulator( + posterior_samples, context_samples + ) + return posterior_predictive_samples diff --git a/src/deepdiagnostics/plots/__init__.py b/src/deepdiagnostics/plots/__init__.py new file mode 100644 index 0000000..a63b51c --- /dev/null +++ b/src/deepdiagnostics/plots/__init__.py @@ -0,0 +1,20 @@ +from deepdiagnostics.plots.cdf_ranks import CDFRanks +from deepdiagnostics.plots.coverage_fraction import CoverageFraction +from deepdiagnostics.plots.ranks import Ranks +from deepdiagnostics.plots.tarp import TARP +from deepdiagnostics.plots.local_two_sample import LocalTwoSampleTest as LC2ST +from deepdiagnostics.plots.predictive_posterior_check import PPC +from deepdiagnostics.plots.parity import Parity +from deepdiagnostics.plots.predictive_prior_check import PriorPC + +Plots = { + "": lambda **kwargs: None, + CDFRanks.__name__: CDFRanks, + CoverageFraction.__name__: CoverageFraction, + Ranks.__name__: Ranks, + TARP.__name__: TARP, + "LC2ST": LC2ST, + PPC.__name__: PPC, + "Parity": Parity, + PriorPC.__name__: PriorPC +} diff --git a/src/deepdiagnostics/plots/cdf_ranks.py b/src/deepdiagnostics/plots/cdf_ranks.py new file mode 100644 index 0000000..99827a0 --- /dev/null +++ b/src/deepdiagnostics/plots/cdf_ranks.py @@ -0,0 +1,63 @@ +from sbi.analysis import sbc_rank_plot, run_sbc +from torch import tensor + +from deepdiagnostics.plots.plot import Display + + +class CDFRanks(Display): + def __init__( + self, + model, + data, + save, + show, + out_dir=None, + percentiles = None, + use_progress_bar= None, + samples_per_inference = None, + number_simulations= None, + parameter_names = None, + parameter_colors = None, + colorway =None): + + """ + Adaptation of :cite:p:`centero2020sbi`. + A wrapper around `SBI `_'s sbc_rank_plot function. + `More information can be found here `_ + Plots the ranks as a CDF plot for each theta parameter. + + .. 
code-block:: python + + from deepdiagnostics.plots import CDFRanks + + CDFRanks(model, data, save=False, show=True)() + + """ + + super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) + + def plot_name(self): + return "cdf_ranks.png" + + def _data_setup(self): + thetas = tensor(self.data.get_theta_true()) + context = tensor(self.data.true_context()) + + ranks, _ = run_sbc( + thetas, context, self.model.posterior, num_posterior_samples=self.samples_per_inference + ) + self.ranks = ranks + + def plot_settings(self): + pass + + def plot(self): + """ + """ + sbc_rank_plot( + self.ranks, + self.samples_per_inference, + plot_type="cdf", + parameter_labels=self.parameter_names, + colors=self.parameter_colors, + ) diff --git a/src/plots/coverage_fraction.py b/src/deepdiagnostics/plots/coverage_fraction.py similarity index 56% rename from src/plots/coverage_fraction.py rename to src/deepdiagnostics/plots/coverage_fraction.py index b8eb8a4..a3e91f6 100644 --- a/src/plots/coverage_fraction.py +++ b/src/deepdiagnostics/plots/coverage_fraction.py @@ -1,35 +1,47 @@ -from typing import Optional, Sequence import numpy as np import matplotlib.pyplot as plt -from metrics.coverage_fraction import CoverageFraction as coverage_fraction_metric -from plots.plot import Display -from utils.config import get_item +from deepdiagnostics.metrics.coverage_fraction import CoverageFraction as coverage_fraction_metric +from deepdiagnostics.plots.plot import Display +from deepdiagnostics.utils.config import get_item class CoverageFraction(Display): + """ + Show posterior regions of confidence as a function of percentiles. + Each parameter of theta is plotted against a coverage fraction for each given theta. + + .. code-block:: python + + from deepdiagnostics.plots import CoverageFraction + + CoverageFraction(model, data, show=True, save=False)( + figure_alpha=0.8, + legend_loc="upper left", + reference_line_label="Ideal" + ) + """ def __init__( self, model, data, - save:bool, - show:bool, - out_dir:Optional[str]=None, - percentiles: Optional[Sequence] = None, - use_progress_bar: Optional[bool] = None, - samples_per_inference: Optional[int] = None, - number_simulations: Optional[int] = None, - parameter_names: Optional[Sequence] = None, - parameter_colors: Optional[Sequence]= None, - colorway: Optional[str]=None - ): + save, + show, + out_dir=None, + percentiles = None, + use_progress_bar= None, + samples_per_inference = None, + number_simulations= None, + parameter_names = None, + parameter_colors = None, + colorway =None): super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) self.n_parameters = len(self.parameter_names) self.line_cycle = tuple(get_item("plots_common", "line_style_cycle", raise_exception=False)) - def _plot_name(self): + def plot_name(self): return "coverage_fraction.png" def _data_setup(self): @@ -38,7 +50,7 @@ def _data_setup(self): ).calculate() self.coverage_fractions = coverage - def _plot( + def plot( self, figure_alpha=1.0, line_width=3, @@ -47,8 +59,19 @@ def _plot( reference_line_style="k--", x_label="Confidence Interval of the Posterior Volume", y_label="Fraction of Lenses within Posterior Volume", - title="NPE", - ): + title="NPE"): + """ + Args: + figure_alpha (float, optional): Opacity of parameter lines. Defaults to 1.0. 
+        line_width (int, optional): Width of parameter lines. Defaults to 3. +        legend_loc (str, optional): Location of the legend, str based on `matplotlib `_. Defaults to "lower right". +        reference_line_label (str, optional): Label name for the diagonal ideal line. Defaults to "Reference Line". +        reference_line_style (str, optional): Line style for the reference line. Defaults to "k--". +        x_label (str, optional): x label. Defaults to "Confidence Interval of the Posterior Volume". +        y_label (str, optional): y label. Defaults to "Fraction of Lenses within Posterior Volume". +        title (str, optional): Plot title. Defaults to "NPE". +        """ + +        n_steps = self.coverage_fractions.shape[0]         percentile_array = np.linspace(0, 1, n_steps)         color_cycler = iter(plt.cycler("color", self.parameter_colors)) diff --git a/src/plots/local_two_sample.py b/src/deepdiagnostics/plots/local_two_sample.py similarity index 71% rename from src/plots/local_two_sample.py rename to src/deepdiagnostics/plots/local_two_sample.py index 79f0f4d..2a4234c 100644 --- a/src/plots/local_two_sample.py +++ b/src/deepdiagnostics/plots/local_two_sample.py @@ -1,15 +1,39 @@ -from typing import Optional, Sequence, Union +from typing import Union import matplotlib.pyplot as plt from matplotlib import cm import numpy as np from matplotlib.colors import Normalize from matplotlib.patches import Rectangle -from plots.plot import Display -from metrics.local_two_sample import LocalTwoSampleTest as l2st -from utils.plotting_utils import get_hex_colors +from deepdiagnostics.plots.plot import Display +from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as l2st +from deepdiagnostics.utils.plotting_utils import get_hex_colors  class LocalTwoSampleTest(Display): +    """ +    Produce plots showing the local evaluation of a posterior estimator for a given observation. +    Adapted from Linhart et al. :cite:p:`linhart2023lc2st`. + +    Implements a pair plot, showing confidence regions of the CDF in comparison with the null hypothesis classifier results, +    and an intensity plot, showing the regions of accuracy for each parameter of theta. + +    Uses the following code as reference material: + +    `github.com/JuliaLinhart/lc2st/graphical_diagnostics.py::pp_plot_lc2st `_. + +    `github.com/JuliaLinhart/lc2st/graphical_diagnostics.py::eval_space_with_proba_intensity `_. + + +    .. 
code-block:: python + +        from deepdiagnostics.plots import LC2ST + +        LC2ST(model, data, save=False, show=True)(use_intensity_plot=True, n_alpha_samples=100) + +    """ + +    # Plots to make - +    # pp_plot_lc2st: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L49 +    # eval_space_with_proba_intensity: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133     # https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 @@ -17,23 +41,22 @@ def __init__( self, model, data, -        save:bool, -        show:bool, -        out_dir:Optional[str]=None, -        percentiles: Optional[Sequence] = None, -        use_progress_bar: Optional[bool] = None, -        samples_per_inference: Optional[int] = None, -        number_simulations: Optional[int] = None, -        parameter_names: Optional[Sequence] = None, -        parameter_colors: Optional[Sequence]= None, -        colorway: Optional[str]=None -    ): +        save, +        show, +        out_dir=None, +        percentiles = None, +        use_progress_bar= None, +        samples_per_inference = None, +        number_simulations= None, +        parameter_names = None, +        parameter_colors = None, +        colorway =None): super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) self.region_colors = get_hex_colors(n_colors=len(self.percentiles), colorway=self.colorway) self.l2st = l2st(model, data, out_dir, True, self.use_progress_bar, self.samples_per_inference, self.percentiles, self.number_simulations) -    def _plot_name(self): +    def plot_name(self): return "local_C2ST.png" def _make_pairplot_values(self, random_samples): @@ -158,7 +181,7 @@ def probability_intensity(self, subplot, features, n_bins=20): ) subplot.add_patch(rect) -    def _plot( +    def plot( self, use_intensity_plot: bool = True, n_alpha_samples: int = 100, @@ -173,12 +196,29 @@ def _plot( pairplot_title="Local Classifier PP-Plot", intensity_plot_ylabel="", intensity_plot_xlabel="", -        intensity_plot_title="Local Classifier Intensity Distribution", -    ): +        intensity_plot_title="Local Classifier Intensity Distribution"): +        """ +        Args: +            use_intensity_plot (bool, optional): Use the additional intensity plots showing regions of prediction accuracy for different theta values. Defaults to True. +            n_alpha_samples (int, optional): Number of samples to use to produce the CDF region. Defaults to 100. +            confidence_region_alpha (float, optional): Opacity of the CDF region plots. Defaults to 0.2. +            n_intensity_bins (int, optional): Number of bins to use when producing the intensity plots. Number of regions. Defaults to 20. +            linear_classifier (Union[str, list[str]], optional): Type of linear classifiers to use. Only MLP is currently implemented. Defaults to "MLP". +            cross_evaluate (bool, optional): Split the validation data into K folds to produce an uncertainty of the classification results. Defaults to True. +            n_null_hypothesis_trials (int, optional): Number of inferences to classify under the null hypothesis. Defaults to 100. +            classifier_kwargs (Union[dict, list[dict]], optional): Additional kwargs for the classifier. Depends on the classifier choice. Defaults to None. +            pairplot_y_label (str, optional): Row label for the pairplot. Defaults to "Empirical CDF". +            pairplot_x_label (str, optional): Column label for the pairplot. Defaults to "". +            pairplot_title (str, optional): Title of the pair plot. Defaults to "Local Classifier PP-Plot". +            intensity_plot_ylabel (str, optional): y axis label for the intensity plot. Defaults to "". 
+            intensity_plot_xlabel (str, optional): x axis label for the intensity plot. Defaults to "". +            intensity_plot_title (str, optional): Title for the intensity plot. Defaults to "Local Classifier Intensity Distribution". +        """ + -        # Plots to make - -        # pp_plot_lc2st: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L49 -        # eval_space_with_proba_intensity: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 - +        self.l2st._collect_data_params() self.probability, self.null_hypothesis_probability = self.l2st.calculate( linear_classifier=linear_classifier, @@ -250,4 +290,4 @@ def _plot( self._finish() def __call__(self, **plot_args) -> None: -        self._plot(**plot_args) +        self.plot(**plot_args) diff --git a/src/plots/parity.py b/src/deepdiagnostics/plots/parity.py similarity index 71% rename from src/plots/parity.py rename to src/deepdiagnostics/plots/parity.py index b504756..92ad9d7 100644 --- a/src/plots/parity.py +++ b/src/deepdiagnostics/plots/parity.py @@ -1,28 +1,40 @@ -from typing import Optional, Sequence import matplotlib.pyplot as plt import numpy as np -from plots.plot import Display +from deepdiagnostics.plots.plot import Display  class Parity(Display): +    """ +    Show plots directly comparing posterior samples against the true theta values. Makes a plot that is (number of selected comparisons) x (dimensions of theta). +    Includes the option to show the difference, residual, and percent residual as panels under the main parity plot. + +    .. code-block:: python + +        from deepdiagnostics.plots import Parity + +        Parity(model, data, show=True, save=False)( +            n_samples=200,  # 200 samples of the posterior +            include_residual = True  # Add a plot showing the residual +        ) +    """ def __init__( self, model, data, -        save:bool, -        show:bool, -        out_dir:Optional[str]=None, -        percentiles: Optional[Sequence] = None, -        use_progress_bar: Optional[bool] = None, -        samples_per_inference: Optional[int] = None, -        number_simulations: Optional[int] = None, -        parameter_names: Optional[Sequence] = None, -        parameter_colors: Optional[Sequence]= None, -        colorway: Optional[str]=None -    ): +        save, +        show, +        out_dir=None, +        percentiles = None, +        use_progress_bar= None, +        samples_per_inference = None, +        number_simulations= None, +        parameter_names = None, +        parameter_colors = None, +        colorway =None): + super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) -    def _plot_name(self): +    def plot_name(self): return "parity.png" def get_posterior(self, n_samples): @@ -40,8 +52,7 @@ def get_posterior(self, n_samples): self.true_samples[index] = self.data.get_theta_true()[sample, :] - - -    def _plot( +    def plot( self, n_samples: int = 80, include_difference: bool = False, @@ -53,6 +64,18 @@ def _plot( y_label:str=r"$\theta_{predicted}$", x_label:str=r"$\theta_{true}$" ): +        """ +        Args: +            n_samples (int, optional): Samples to draw from the posterior for the plot. Defaults to 80. +            include_difference (bool, optional): Include a plot that shows the difference between the posterior and true. Defaults to False. +            include_residual (bool, optional): Include a plot that shows the residual between posterior and true. Defaults to False. +            include_percentage (bool, optional): Include a plot that shows the residual as a percent between posterior and true. Defaults to False. 
+            show_ideal (bool, optional): Include a line showing where posterior=true. Defaults to True. +            errorbar_color (str, optional): Color of the error bars. Defaults to 'black'. +            title (str, optional): Title of the plot. Defaults to "Parity". +            y_label (str, optional): y axis label. Defaults to r"$\theta_{predicted}$". +            x_label (str, optional): x axis label. Defaults to r"$\theta_{true}$". +        """ self.get_posterior(n_samples) # parity - predicted vs true diff --git a/src/plots/plot.py b/src/deepdiagnostics/plots/plot.py similarity index 64% rename from src/plots/plot.py rename to src/deepdiagnostics/plots/plot.py index b3abf9e..26b81e3 100644 --- a/src/plots/plot.py +++ b/src/deepdiagnostics/plots/plot.py @@ -3,10 +3,28 @@ import matplotlib.pyplot as plt from matplotlib import rcParams -from utils.config import get_item +from deepdiagnostics.utils.config import get_item  class Display: +    """ +    Parameters common to all plots. + +    Args: +        model (deepdiagnostics.models.model): Model to calculate the metric for. Required. +        data (deepdiagnostics.data.data): Data to test against. Required. +        out_dir (Optional[str], optional): Directory to save a png ({plot_name}.png) to. Defaults to None. +        save (bool, optional): Save the output to png. +        show (bool, optional): Show the completed plot when finished. +        use_progress_bar (Optional[bool], optional): Show a progress bar when iteratively performing inference. Defaults to None. +        samples_per_inference (Optional[int], optional): Number of samples used in a single iteration of inference. Defaults to None. +        percentiles (Optional[Sequence[int]], optional): List of integer percentiles, for defining coverage regions. Defaults to None. +        number_simulations (Optional[int], optional): Number of different simulations to run. Often, this means that the number of inferences performed for a metric is samples_per_inference*number_simulations. Defaults to None. +        parameter_names (Optional[Sequence], optional): Name of each theta parameter to use for titling and labels. Corresponds to the dim=1 axis of theta given by data. Defaults to None. +        parameter_colors (Optional[Sequence], optional): Colors to use for each theta parameter when representing the parameters on the same plot. Defaults to None. +        colorway (Optional[str], optional): String colorway to use. Choose from `matplotlib's named colorways `_. Defaults to None. +    """ + def __init__( self, model, @@ -20,9 +38,8 @@ def __init__( number_simulations: Optional[int] = None, parameter_names: Optional[Sequence] = None, parameter_colors: Optional[Sequence]= None, -        colorway: Optional[str]=None -    ): - +        colorway: Optional[str]=None): + self.save = save self.show = show self.data = data @@ -47,16 +64,16 @@ def __init__( self.model = model self._common_settings() -        self.plot_name = self._plot_name() +        self.plot_name = self.plot_name() -    def _plot_name(self): +    def plot_name(self): raise NotImplementedError def _data_setup(self): # Set all the vars used for the plot raise NotImplementedError -    def _plot(self, **kwrgs): +    def plot(self, **kwrgs): # Make the plot object with plt. 
raise NotImplementedError @@ -96,5 +113,5 @@ def __call__(self, **plot_args) -> None: except NotImplementedError: pass - self._plot(**plot_args) + self.plot(**plot_args) self._finish() diff --git a/src/plots/predictive_posterior_check.py b/src/deepdiagnostics/plots/predictive_posterior_check.py similarity index 75% rename from src/plots/predictive_posterior_check.py rename to src/deepdiagnostics/plots/predictive_posterior_check.py index 9f7c06e..09e4162 100644 --- a/src/plots/predictive_posterior_check.py +++ b/src/deepdiagnostics/plots/predictive_posterior_check.py @@ -1,30 +1,40 @@ -from typing import Optional, Sequence +from typing import Optional import matplotlib.pyplot as plt import numpy as np -from sklearn.model_selection import KFold -from plots.plot import Display -from utils.plotting_utils import get_hex_colors +from deepdiagnostics.plots.plot import Display +from deepdiagnostics.utils.plotting_utils import get_hex_colors class PPC(Display): + """ + Show the output of the model's generated posterior against the true values for the same context. + Can show either output vs input (in 1D) or examples of simulation output (in 2D). + + .. code-block:: python + + from deepdiagnostics.plots import PPC + + PPC(model, data, save=False, show=True)(n_unique_plots=5) # Plot 5 examples + + """ def __init__( self, model, data, - save:bool, - show:bool, - out_dir:Optional[str]=None, - percentiles: Optional[Sequence] = None, - use_progress_bar: Optional[bool] = None, - samples_per_inference: Optional[int] = None, - number_simulations: Optional[int] = None, - parameter_names: Optional[Sequence] = None, - parameter_colors: Optional[Sequence]= None, - colorway: Optional[str]=None - ): + save, + show, + out_dir=None, + percentiles = None, + use_progress_bar= None, + samples_per_inference = None, + number_simulations= None, + parameter_names = None, + parameter_colors = None, + colorway =None): + super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) - def _plot_name(self): + def plot_name(self): return "predictive_posterior_check.png" def get_posterior_2d(self, n_simulator_draws): @@ -80,7 +90,7 @@ def get_posterior_1d(self, n_simulator_draws): theta=self.data.get_theta_true()[sample, :], context_samples=context_sample ) - def _plot_1d(self, + def plot_1d(self, subplots: np.ndarray, subplot_index: int, n_coverage_sigma: Optional[int] = 3, @@ -123,7 +133,7 @@ def _plot_1d(self, label='Theta True' ) - def _plot_2d(self, subplots, subplot_index, include_axis_ticks): + def plot_2d(self, subplots, subplot_index, include_axis_ticks): subplots[1, subplot_index].imshow(self.posterior_predictive_samples[subplot_index]) subplots[0, subplot_index].imshow(self.posterior_true_samples[subplot_index]) @@ -134,7 +144,7 @@ def _plot_2d(self, subplots, subplot_index, include_axis_ticks): subplots[0, subplot_index].set_xticks([]) subplots[0, subplot_index].set_yticks([]) - def _plot( + def plot( self, n_coverage_sigma: Optional[int] = 3, true_sigma: Optional[float] = None, @@ -144,6 +154,20 @@ def _plot( title:str="Predictive Posterior", y_label:str="Simulation Output", x_label:str="X"): + """ + Args: + n_coverage_sigma (Optional[int], optional): Show the N different standard dev. sigma of the posterior results. Only used in 1D. Defaults to 3. + true_sigma (Optional[float], optional): True std. of the known posterior. Used only if supplied. Defaults to None. 
+            theta_true_marker (Optional[str], optional): Marker to use for output of the true theta parameters. Only used in 1D. Defaults to '^'. +            n_unique_plots (Optional[int], optional): Number of samples of theta/x to use. Each one corresponds to a column. Defaults to 3. +            include_axis_ticks (bool, optional): Show axis ticks on the 2D image subplots. Only used in 2D. Defaults to False. +            title (str, optional): Title of the plot. Defaults to "Predictive Posterior". +            y_label (str, optional): Row label. Defaults to "Simulation Output". +            x_label (str, optional): Column label. Defaults to "X". + +        Raises: +            NotImplementedError: If trying to plot results of a simulation with more than 2 output dimensions. +        """ if self.data.simulator_dimensions == 1: self.get_posterior_1d(n_unique_plots) @@ -166,10 +190,10 @@ def _plot( for plot_index in range(n_unique_plots): if self.data.simulator_dimensions == 1: -                self._plot_1d(subplots, plot_index, n_coverage_sigma, theta_true_marker) +                self.plot_1d(subplots, plot_index, n_coverage_sigma, theta_true_marker) else: -                self._plot_2d(subplots, plot_index, include_axis_ticks) +                self.plot_2d(subplots, plot_index, include_axis_ticks) subplots[1, 0].set_ylabel("True Parameters") diff --git a/src/plots/predictive_prior_check.py b/src/deepdiagnostics/plots/predictive_prior_check.py similarity index 54% rename from src/plots/predictive_prior_check.py rename to src/deepdiagnostics/plots/predictive_prior_check.py index 677fc08..bd528c6 100644 --- a/src/plots/predictive_prior_check.py +++ b/src/deepdiagnostics/plots/predictive_prior_check.py @@ -1,35 +1,67 @@ -from typing import Optional, Sequence +from typing import Optional import matplotlib.pyplot as plt import numpy as np -from plots.plot import Display +from deepdiagnostics.plots.plot import Display  class PriorPC(Display): +    """ +    Plot random samples of the simulator's output for parameter samples drawn from the prior. + +    .. 
code-block:: python + + from deepdiagnostics.plots import PriorPC + + PriorPC(model, data, show=True, save=False)( + n_rows = 2, + n_columns = 6, # Make 2x6 = 12 different samples + row_parameter_index = 0, + column_parameter_index = 1, # Include labels for theta parameters 0 and 1 from the prior + round_parameters = True, + ) + """ def __init__( self, model, data, - save:bool, - show:bool, - out_dir:Optional[str]=None, - percentiles: Optional[Sequence] = None, - use_progress_bar: Optional[bool] = None, - samples_per_inference: Optional[int] = None, - number_simulations: Optional[int] = None, - parameter_names: Optional[Sequence] = None, - parameter_colors: Optional[Sequence]= None, - colorway: Optional[str]=None - ): + save, + show, + out_dir=None, + percentiles = None, + use_progress_bar= None, + samples_per_inference = None, + number_simulations= None, + parameter_names = None, + parameter_colors = None, + colorway = None): + super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) + + if self.data.simulator_dimensions == 1: + self.plot_image = False - def _plot_name(self): + elif self.data.simulator_dimensions == 2: + self.plot_image = True + + + def plot_name(self): return "predictive_prior_check.png" def get_prior_samples(self, n_columns, n_rows): context_shape = self.data.true_context().shape - - self.prior_predictive_samples = np.zeros((n_rows, n_columns, context_shape[-1])) + remove_first_dim = False + + if self.plot_image: + sim_out_shape = self.data.get_simulator_output_shape() + if len(sim_out_shape) != 2: + # TODO Debug log with a warning + sim_out_shape = (sim_out_shape[1], sim_out_shape[2]) + remove_first_dim = True + self.prior_predictive_samples = np.zeros((n_rows, n_columns, *sim_out_shape)) + + else: + self.prior_predictive_samples = np.zeros((n_rows, n_columns, context_shape[-1])) self.prior_sample = np.zeros((n_rows, n_columns, self.data.n_dims)) self.context = np.zeros((n_rows, n_columns, context_shape[-1])) random_context_indices = self.data.rng.integers(0, context_shape[0], (n_rows, n_columns)) @@ -42,9 +74,13 @@ def get_prior_samples(self, n_columns, n_rows): prior_sample = self.data.sample_prior(1)[0] # get the posterior samples for that context - self.prior_predictive_samples[row_index, column_index] = self.data.simulator.simulate( + simulation_sample = self.data.simulator.simulate( theta=prior_sample, context_samples = context_sample ) + if remove_first_dim: + simulation_sample = simulation_sample[0] + + self.prior_predictive_samples[row_index, column_index] = simulation_sample self.prior_sample[row_index, column_index] = prior_sample self.context[row_index, column_index] = context_sample @@ -70,7 +106,7 @@ def fill_text(self, row_index, column_index, row_parameter_index, column_paramet raise ValueError(f"Cannot use {label_samples} to assign labels. Choose from 'both', 'rows', 'columns'.") - def _plot( + def plot( self, n_rows: Optional[int] = 3, n_columns: Optional[int] = 3, @@ -83,7 +119,21 @@ def _plot( title:Optional[str]="Simulated output from prior", y_label:Optional[str]=None, x_label:str=None): - + """ + + Args: + n_rows (Optional[int], optional): Number of unique rows to make for priors. Defaults to 3. + n_columns (Optional[int], optional): Number of unique columns for viewing prior predictions. Defaults to 3. + row_parameter_index (Optional[int], optional): Index of the theta parameter to display as labels on rows. Defaults to 0. 
+ column_parameter_index (Optional[int], optional): Index of the theta parameter to display as labels on columns. Defaults to 1. + round_parameters (Optional[bool], optional): In labels, round the theta parameters (recommended when thetas are float values). Defaults to True. + sort_rows (bool, optional): Sort the plots by the theta row value. Defaults to True. + sort_columns (bool, optional): Sort the plots by theta column value. Defaults to True. + label_samples (Optional[str], optional): Label the prior values as a text box in each label. Row means using row_parameter_index as the title value. Choose from "rows", "columns", "both". Defaults to 'both'. + title (Optional[str], optional): Title of the whole plot. Defaults to "Simulated output from prior". + y_label (Optional[str], optional): Column label, when None, label = `theta_{column_index} = parameter_name`. Defaults to None. + x_label (str, optional): Row label, when None, label = `theta_{row_index} = parameter_name`. Defaults to None. + """ self.get_prior_samples(n_rows, n_columns) figure, subplots = plt.subplots( @@ -123,10 +173,16 @@ def _plot( ) subplots[plot_row_index, plot_column_index].title.set_text(text) - subplots[plot_row_index, plot_column_index].plot( - self.context[column_index, row_index], - self.prior_predictive_samples[column_index, row_index] - ) + if self.plot_image: + subplots[plot_row_index, plot_column_index].imshow(self.prior_predictive_samples[column_index, row_index]) + subplots[plot_row_index, plot_column_index].set_xticks([]) + subplots[plot_row_index, plot_column_index].set_yticks([]) + + else: + subplots[plot_row_index, plot_column_index].plot( + self.context[column_index, row_index], + self.prior_predictive_samples[column_index, row_index] + ) figure.supylabel(y_label) figure.supxlabel(x_label) diff --git a/src/deepdiagnostics/plots/ranks.py b/src/deepdiagnostics/plots/ranks.py new file mode 100644 index 0000000..0915066 --- /dev/null +++ b/src/deepdiagnostics/plots/ranks.py @@ -0,0 +1,62 @@ +from sbi.analysis import sbc_rank_plot, run_sbc +from torch import tensor + +from deepdiagnostics.plots.plot import Display + + +class Ranks(Display): + """ + + Adaptation of :cite:p:`centero2020sbi`. + + A wrapper around `SBI `_'s sbc_rank_plot function. + `More information can be found here `_ + Plots the histogram of each theta parameter's rank. + + .. code-block:: python + + from deepdiagnostics.plots import Ranks + + Ranks(model, data, save=False, show=True)(num_bins=25) + """ + def __init__( + self, + model, + data, + save, + show, + out_dir=None, + percentiles = None, + use_progress_bar= None, + samples_per_inference = None, + number_simulations= None, + parameter_names = None, + parameter_colors = None, + colorway =None): + + super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) + + def plot_name(self): + return "ranks.png" + + def _data_setup(self): + thetas = tensor(self.data.get_theta_true()) + context = tensor(self.data.true_context()) + ranks, _ = run_sbc( + thetas, context, self.model.posterior, num_posterior_samples=self.samples_per_inference + ) + self.ranks = ranks + + def plot(self, num_bins:int=20): + """ + Args: + num_bins (int): Number of histogram bins. Defaults to 20. 
+        """ +        sbc_rank_plot( +            ranks=self.ranks, +            num_posterior_samples=self.samples_per_inference, +            plot_type="hist", +            num_bins=num_bins, +            parameter_labels=self.parameter_names, +            colors=self.parameter_colors, +        ) diff --git a/src/plots/tarp.py b/src/deepdiagnostics/plots/tarp.py similarity index 63% rename from src/plots/tarp.py rename to src/deepdiagnostics/plots/tarp.py index 038d3f7..0523f5f 100644 --- a/src/plots/tarp.py +++ b/src/deepdiagnostics/plots/tarp.py @@ -1,35 +1,50 @@ -from typing import Optional, Sequence, Union +from typing import Union import numpy as np import tarp import matplotlib.pyplot as plt import matplotlib.colors as plt_colors -from plots.plot import Display -from utils.config import get_item +from deepdiagnostics.plots.plot import Display +from deepdiagnostics.utils.config import get_item  class TARP(Display): +    """ +    Produce a TARP plot as described in Lemos et al. :cite:p:`lemos2023samplingbased`. +    Utilizes the implementation from `here `_. + +    .. code-block:: python + +        from deepdiagnostics.plots import TARP + +        TARP(model, data, show=True, save=False)( +            coverage_sigma=2, +            coverage_alpha=0.4, +            y_label="Credibility Level" +        ) + +    """ def __init__( self, model, data, -        save:bool, -        show:bool, -        out_dir:Optional[str]=None, -        percentiles: Optional[Sequence] = None, -        use_progress_bar: Optional[bool] = None, -        samples_per_inference: Optional[int] = None, -        number_simulations: Optional[int] = None, -        parameter_names: Optional[Sequence] = None, -        parameter_colors: Optional[Sequence]= None, -        colorway: Optional[str]=None -    ): +        save, +        show, +        out_dir=None, +        percentiles = None, +        use_progress_bar= None, +        samples_per_inference = None, +        number_simulations= None, +        parameter_names = None, +        parameter_colors = None, +        colorway =None): + super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) self.line_style = get_item( "plots_common", "line_style_cycle", raise_exception=False ) -    def _plot_name(self): +    def plot_name(self): return "tarp.png" def _data_setup(self): @@ -51,7 +66,7 @@ def _data_setup(self): self.posterior_samples = np.swapaxes(self.posterior_samples, 0, 1) -    def _plot_settings(self): +    def plot_settings(self): self.line_style = get_item( "plots_common", "line_style_cycle", raise_exception=False ) @@ -66,19 +81,31 @@ def _get_hex_sigma_colors(self, n_colors): return hex_colors -    def _plot( +    def plot( self, coverage_sigma: int = 3, reference_point: Union[str, np.ndarray] = "random", metric: bool = "euclidean", normalize: bool = True, bootstrap_calculation: bool = True, -        coverage_colorway: Optional[str] = None, coverage_alpha: float = 0.2, y_label: str = "Expected Coverage", x_label: str = "Expected Coverage", title: str = "Test of Accuracy with Random Points", ): +        """ +        Args: +            coverage_sigma (int, optional): Number of sigma to use for coverage. Defaults to 3. +            reference_point (Union[str, np.ndarray], optional): Reference points in the parameter space to test against. Defaults to "random". +            metric (str, optional): Distance metric to use between reference points. Use "euclidean" or "manhattan". Defaults to "euclidean". +            normalize (bool, optional): Normalize input space to 1. Defaults to True. +            bootstrap_calculation (bool, optional): Estimate uncertainties using bootstrapped examples. Increases efficiency. Defaults to True. +            coverage_alpha (float, optional): Opacity of the different coverage sigma regions. Defaults to 0.2. 
+ y_label (str, optional): Sup. label on the y axis. Defaults to "Expected Coverage". + x_label (str, optional): Sup. label on the x axis. Defaults to "Expected Coverage". + title (str, optional): Title of the entire figure. Defaults to "Test of Accuracy with Random Points". + + """ coverage_probability, credibility = tarp.get_tarp_coverage( self.posterior_samples, self.thetas, diff --git a/src/client/__init__.py b/src/deepdiagnostics/utils/__init__.py similarity index 100% rename from src/client/__init__.py rename to src/deepdiagnostics/utils/__init__.py diff --git a/src/utils/config.py b/src/deepdiagnostics/utils/config.py similarity index 97% rename from src/utils/config.py rename to src/deepdiagnostics/utils/config.py index 8bf2d9b..a07a170 100644 --- a/src/utils/config.py +++ b/src/deepdiagnostics/utils/config.py @@ -2,7 +2,7 @@ import os import yaml -from utils.defaults import Defaults +from deepdiagnostics.utils.defaults import Defaults def get_item(section, item, raise_exception=True): diff --git a/src/utils/defaults.py b/src/deepdiagnostics/utils/defaults.py similarity index 95% rename from src/utils/defaults.py rename to src/deepdiagnostics/utils/defaults.py index b132951..d9943ec 100644 --- a/src/utils/defaults.py +++ b/src/deepdiagnostics/utils/defaults.py @@ -2,7 +2,7 @@ "common": { "out_dir": "./DeepDiagnosticsResources/results/", "temp_config": "./DeepDiagnosticsResources/temp/temp_config.yml", - "sim_location": "DeepDiagnosticsResources/simulators", + "sim_location": "deepdiagnosticsResources/simulators", "random_seed": 42, }, "model": {"model_engine": "SBIModel"}, @@ -34,7 +34,6 @@ "Parity":{}, "PPC": {}, "PriorPC":{} - }, "metrics_common": { "use_progress_bar": False, diff --git a/src/utils/plotting_utils.py b/src/deepdiagnostics/utils/plotting_utils.py similarity index 100% rename from src/utils/plotting_utils.py rename to src/deepdiagnostics/utils/plotting_utils.py diff --git a/src/utils/register.py b/src/deepdiagnostics/utils/register.py similarity index 98% rename from src/utils/register.py rename to src/deepdiagnostics/utils/register.py index 60fdcfe..3021af9 100644 --- a/src/utils/register.py +++ b/src/deepdiagnostics/utils/register.py @@ -4,7 +4,7 @@ import sys import json -from utils.config import get_item +from deepdiagnostics.utils.config import get_item def register_simulator(simulator_name, simulator): diff --git a/src/utils/variable_store.py b/src/deepdiagnostics/utils/variable_store.py similarity index 100% rename from src/utils/variable_store.py rename to src/deepdiagnostics/utils/variable_store.py diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py deleted file mode 100644 index 9bb6fe9..0000000 --- a/src/metrics/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from metrics.all_sbc import AllSBC -from metrics.coverage_fraction import CoverageFraction -from metrics.local_two_sample import LocalTwoSampleTest - -Metrics = { - CoverageFraction.__name__: CoverageFraction, - AllSBC.__name__: AllSBC, - "LC2ST": LocalTwoSampleTest -} diff --git a/src/models/__init__.py b/src/models/__init__.py deleted file mode 100644 index 12b2264..0000000 --- a/src/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from models.sbi_model import SBIModel - -ModelModules = {"SBIModel": SBIModel} diff --git a/src/models/sbi_model.py b/src/models/sbi_model.py deleted file mode 100644 index 9085402..0000000 --- a/src/models/sbi_model.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -import pickle - -from models.model import Model - - -class SBIModel(Model): - def __init__(self, 
model_path: str): - super().__init__(model_path) - - def _load(self, path: str) -> None: - assert os.path.exists(path), f"Cannot find model file at location {path}" - assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" - - with open(path, "rb") as file: - posterior = pickle.load(file) - self.posterior = posterior - - def sample_posterior(self, n_samples: int, y_true): # TODO typing - return self.posterior.sample( - (n_samples,), x=y_true, show_progress_bars=False - ).cpu() # TODO Unbind from cpu - - def predict_posterior(self, data): - posterior_samples = self.sample_posterior(data.y_true) - posterior_predictive_samples = data.simulator( - data.theta_true(), posterior_samples - ) - return posterior_predictive_samples diff --git a/src/plots/__init__.py b/src/plots/__init__.py deleted file mode 100644 index fb18498..0000000 --- a/src/plots/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from plots.cdf_ranks import CDFRanks -from plots.coverage_fraction import CoverageFraction -from plots.ranks import Ranks -from plots.tarp import TARP -from plots.local_two_sample import LocalTwoSampleTest -from plots.predictive_posterior_check import PPC -from plots.predictive_prior_check import PriorPC -from plots.parity import Parity -from plots.predictive_prior_check import PriorPC - -Plots = { - CDFRanks.__name__: CDFRanks, - CoverageFraction.__name__: CoverageFraction, - Ranks.__name__: Ranks, - TARP.__name__: TARP, - "LC2ST": LocalTwoSampleTest, - PPC.__name__: PPC, - "Parity": Parity, - PriorPC.__name__: PriorPC -} diff --git a/src/plots/cdf_ranks.py b/src/plots/cdf_ranks.py deleted file mode 100644 index 668b977..0000000 --- a/src/plots/cdf_ranks.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Optional, Sequence -from sbi.analysis import sbc_rank_plot, run_sbc -from torch import tensor - -from plots.plot import Display - - -class CDFRanks(Display): - def __init__( - self, - model, - data, - save:bool, - show:bool, - out_dir:Optional[str]=None, - percentiles: Optional[Sequence] = None, - use_progress_bar: Optional[bool] = None, - samples_per_inference: Optional[int] = None, - number_simulations: Optional[int] = None, - parameter_names: Optional[Sequence] = None, - parameter_colors: Optional[Sequence]= None, - colorway: Optional[str]=None - ): - - super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) - - def _plot_name(self): - return "cdf_ranks.png" - - def _data_setup(self): - thetas = tensor(self.data.get_theta_true()) - context = tensor(self.data.true_context()) - - ranks, _ = run_sbc( - thetas, context, self.model.posterior, num_posterior_samples=self.samples_per_inference - ) - self.ranks = ranks - - def _plot_settings(self): - pass - - def _plot(self): - sbc_rank_plot( - self.ranks, - self.samples_per_inference, - plot_type="cdf", - parameter_labels=self.parameter_names, - colors=self.parameter_colors, - ) diff --git a/src/plots/ranks.py b/src/plots/ranks.py deleted file mode 100644 index 4b9dc12..0000000 --- a/src/plots/ranks.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Optional, Sequence -from sbi.analysis import sbc_rank_plot, run_sbc -from torch import tensor - -from plots.plot import Display -from utils.config import get_item - - -class Ranks(Display): - def __init__( - self, - model, - data, - save:bool, - show:bool, - out_dir:Optional[str]=None, - percentiles: Optional[Sequence] = None, - use_progress_bar: Optional[bool] = None, - 
samples_per_inference: Optional[int] = None, - number_simulations: Optional[int] = None, - parameter_names: Optional[Sequence] = None, - parameter_colors: Optional[Sequence]= None, - colorway: Optional[str]=None - ): - - super().__init__(model, data, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) - - def _plot_name(self): - return "ranks.png" - - def _data_setup(self): - thetas = tensor(self.data.get_theta_true()) - context = tensor(self.data.true_context()) - ranks, _ = run_sbc( - thetas, context, self.model.posterior, num_posterior_samples=self.samples_per_inference - ) - self.ranks = ranks - - def _plot(self, num_bins=None): - sbc_rank_plot( - ranks=self.ranks, - num_posterior_samples=self.samples_per_inference, - plot_type="hist", - num_bins=num_bins, - parameter_labels=self.parameter_names, - colors=self.parameter_colors, - ) diff --git a/src/utils/__init__.py b/src/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py index 33f22db..672c9d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,11 +5,11 @@ import numpy as np from deepbench.astro_object import StarObject -from data import H5Data -from data.simulator import Simulator -from models import SBIModel -from utils.config import get_item -from utils.register import register_simulator +from deepdiagnostics.data import H5Data +from deepdiagnostics.data.simulator import Simulator +from deepdiagnostics.models import SBIModel +from deepdiagnostics.utils.config import get_item +from deepdiagnostics.utils.register import register_simulator class MockSimulator(Simulator): @@ -67,7 +67,7 @@ def simulate(self, theta, context_samples: np.ndarray): return np.array(generated_stars) @pytest.fixture(autouse=True) -def setUp(): +def setUp(result_output): register_simulator("MockSimulator", MockSimulator) register_simulator("Mock2DSimulator", Mock2DSimulator) yield @@ -76,10 +76,9 @@ def setUp(): sim_paths = f"{simulator_config_path.strip('/')}/simulators.json" os.remove(sim_paths) - out_dir = get_item("common", "out_dir", raise_exception=False) os.makedirs("resources/test_results/", exist_ok=True) - shutil.copytree(out_dir, "resources/test_results/", dirs_exist_ok=True) - shutil.rmtree(out_dir) + shutil.copytree(result_output, "resources/test_results/", dirs_exist_ok=True) + shutil.rmtree(result_output) @pytest.fixture def model_path(): @@ -92,7 +91,10 @@ def data_path(): @pytest.fixture def result_output(): - return "./temp_results/" + path = "./temp_results/" + if not os.path.exists(path): + os.makedirs(path) + return path @pytest.fixture def simulator_name(): diff --git a/tests/test_client.py b/tests/test_client.py index 7c987ae..d628afa 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,6 +11,10 @@ def test_parser_args(model_path, data_path, simulator_name): data_path, "--simulator", simulator_name, + "--metrics", + "", + "--plots", + "" ] process = subprocess.run(command) exit_code = process.returncode @@ -20,7 +24,7 @@ def test_parser_args(model_path, data_path, simulator_name): def test_parser_config(config_factory, model_path, data_path, simulator_name): config_path = config_factory( - model_path=model_path, data_path=data_path, simulator=simulator_name + model_path=model_path, data_path=data_path, simulator=simulator_name, metrics=[], plots=[] ) command = ["diagnose", "--config", config_path] process = subprocess.run(command) @@ -28,15 +32,13 @@ def 
test_parser_config(config_factory, model_path, data_path, simulator_name): assert exit_code == 0 -def test_main_no_methods(config_factory, model_path, data_path, simulator_name): - out_dir = "./test_out_dir/" +def test_main_no_methods(config_factory, model_path, data_path, simulator_name, result_output): config_path = config_factory( model_path=model_path, data_path=data_path, simulator=simulator_name, plots=[], - metrics=[], - out_dir=out_dir, + metrics=[] ) command = ["diagnose", "--config", config_path] process = subprocess.run(command) @@ -44,7 +46,7 @@ def test_main_no_methods(config_factory, model_path, data_path, simulator_name): assert exit_code == 0 # There should be nothing at the outpath - assert os.listdir(out_dir) == [] + assert os.listdir(result_output) == [] def test_main_missing_config(): diff --git a/tests/test_metrics.py b/tests/test_metrics.py index eeda686..4c9b212 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,13 +1,11 @@ import os import pytest -from utils.defaults import Defaults -from utils.config import Config -from metrics import ( - Metrics, +from deepdiagnostics.utils.config import Config +from deepdiagnostics.metrics import ( CoverageFraction, AllSBC, - LocalTwoSampleTest + LC2ST ) @pytest.fixture @@ -17,41 +15,30 @@ def metric_config(config_factory): "samples_per_inference": 10, "percentiles": [95], } - config = config_factory(metrics_settings=metrics_settings) - Config(config) + return config_factory(metrics_settings=metrics_settings) - -def test_all_defaults(metric_config, mock_model, mock_data): - """ - Ensures each metric has a default set of parameters and is included in the defaults list - Ensures each test can initialize, regardless of the veracity of the output - """ - - for metric_name, metric_obj in Metrics.items(): - assert metric_name in Defaults["metrics"] - metric_obj(mock_model, mock_data) - - def test_coverage_fraction(metric_config, mock_model, mock_data): + Config(metric_config) coverage_fraction = CoverageFraction(mock_model, mock_data, save=True) _, coverage = coverage_fraction.calculate() assert coverage_fraction.output.all() is not None - # TODO Shape of coverage - assert coverage.shape + assert coverage.shape == (1, 2) # One percentile over 2 dimensions of theta. 
coverage_fraction = CoverageFraction(mock_model, mock_data, save=True) coverage_fraction() assert os.path.exists(f"{coverage_fraction.out_dir}/diagnostic_metrics.json") -def test_all_sbc(metric_config, mock_model, mock_data): +def test_all_sbc(metric_config, mock_model, mock_data): + Config(metric_config) all_sbc = AllSBC(mock_model, mock_data, save=True) all_sbc() assert all_sbc.output is not None assert os.path.exists(f"{all_sbc.out_dir}/diagnostic_metrics.json") -def test_lc2st(metric_config, mock_model, mock_data): - lc2st = LocalTwoSampleTest(mock_model, mock_data, save=True) +def test_lc2st(metric_config, mock_model, mock_data): + Config(metric_config) + lc2st = LC2ST(mock_model, mock_data, save=True) lc2st() assert lc2st.output is not None assert os.path.exists(f"{lc2st.out_dir}/diagnostic_metrics.json") diff --git a/tests/test_plots.py b/tests/test_plots.py index 161786c..1d074a3 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,16 +1,14 @@ import os import pytest -from utils.defaults import Defaults -from utils.config import Config, get_item +from deepdiagnostics.utils.config import Config, get_item -from plots import ( - Plots, +from deepdiagnostics.plots import ( CDFRanks, Ranks, CoverageFraction, TARP, - LocalTwoSampleTest, + LC2ST, PPC, PriorPC, Parity @@ -27,16 +25,6 @@ def plot_config(config_factory): config = config_factory(metrics_settings=metrics_settings) return config - -def test_all_defaults(plot_config, mock_model, mock_data): - """ - Ensures each metric has a default set of parameters and is included in the defaults list - Ensures each test can initialize, regardless of the veracity of the output - """ - Config(plot_config) - for plot_name, plot_obj in Plots.items(): - assert plot_name in Defaults["plots"] - plot_obj(mock_model, mock_data, save=True, show=False) def test_plot_cdf(plot_config, mock_model, mock_data): Config(plot_config) @@ -64,11 +52,11 @@ def test_plot_tarp(plot_config, mock_model, mock_data): def test_lc2st(plot_config, mock_model, mock_data, mock_2d_data, result_output): Config(plot_config) - plot = LocalTwoSampleTest(mock_model, mock_data, save=True, show=False) + plot = LC2ST(mock_model, mock_data, save=True, show=False) plot(**get_item("plots", "LC2ST", raise_exception=False)) assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}") - plot = LocalTwoSampleTest( + plot = LC2ST( mock_model, mock_2d_data, save=True, show=False, out_dir=f"{result_output.strip('/')}/mock_2d/") assert type(plot.data.simulator).__name__ == "Mock2DSimulator" @@ -89,24 +77,17 @@ def test_ppc(plot_config, mock_model, mock_data, mock_2d_data, result_output): plot(**get_item("plots", "PPC", raise_exception=False)) -def test_prior_pc(plot_config, mock_model, mock_data): - Config(plot_config) - plot = PriorPC(mock_model, mock_data, save=True, show=False) - plot(**get_item("plots", "PriorPC", raise_exception=False)) - assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}") - - def test_prior_pc(plot_config, mock_model, mock_2d_data, mock_data, result_output): Config(plot_config) plot = PriorPC(mock_model, mock_data, save=True, show=False) plot(**get_item("plots", "PriorPC", raise_exception=False)) assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}") - plot = PPC( + plot = PriorPC( mock_model, mock_2d_data, save=True, show=False, out_dir=f"{result_output.strip('/')}/mock_2d/") assert type(plot.data.simulator).__name__ == "Mock2DSimulator" - plot(**get_item("plots", "PPC", raise_exception=False)) + plot(**get_item("plots", "PriorPC", 
raise_exception=False)) assert os.path.exists(f"{plot.out_dir}/{plot.plot_name}") def test_parity(plot_config, mock_model, mock_data):