Skip to content

Commit

Permalink
Models and data documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
voetberg committed Jun 21, 2024
1 parent 7e57616 commit 3db2b86
Show file tree
Hide file tree
Showing 15 changed files with 320 additions and 63 deletions.
5 changes: 0 additions & 5 deletions docs/source/API/client.rst

This file was deleted.

6 changes: 0 additions & 6 deletions docs/source/API/utils.rst

This file was deleted.

35 changes: 35 additions & 0 deletions docs/source/client.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
Client
========

.. note::
When running the client, you can supply **either** the configuration yaml file, or the CLI arguments.
You do not need to supply both.

Use the command `diagnose -h` to view all usage of the CLI helper at any time.
Specific argument descriptions and explanations can be found on the :ref:`configuration` page.

.. code-block:: bash
usage: diagnose [-h] [--config CONFIG] [--model_path MODEL_PATH] [--model_engine {SBIModel}] [--data_path DATA_PATH] [--data_engine {H5Data,PickleData}]
[--simulator SIMULATOR] [--out_dir OUT_DIR] [--metrics [{CoverageFraction,AllSBC,LC2ST}]]
[--plots [{CDFRanks,CoverageFraction,Ranks,TARP,LC2ST,PPC}]]
options:
-h, --help show this help message and exit
--config CONFIG, -c CONFIG
.yaml file with all arguments to run.
--model_path MODEL_PATH, -m MODEL_PATH
String path to a model. Must be compatible with your model_engine choice.
--model_engine {SBIModel}, -e {SBIModel}
Way to load your model. See each module's documentation page for requirements and specifications.
--data_path DATA_PATH, -d DATA_PATH
String path to data. Must be compatible with data_engine choice.
--data_engine {H5Data,PickleData}, -g {H5Data,PickleData}
Way to load your data. See each module's documentation page for requirements and specifications.
--simulator SIMULATOR, -s SIMULATOR
String name of the simulator to use with generative metrics and plots. Must be pre-register with the `utils.register_simulator` method.
--out_dir OUT_DIR Where the results will be saved. Path need not exist, it will be created.
--metrics [{CoverageFraction,AllSBC,LC2ST}]
List of metrics to run. To not run any, supply `--metrics `
--plots [{CDFRanks,CoverageFraction,Ranks,TARP,LC2ST,PPC}]
List of plots to run. To not run any, supply `--plots `
10 changes: 6 additions & 4 deletions docs/source/configuration.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.. _configuration:

Configuration
===============

Expand All @@ -23,7 +25,7 @@ but it can be specified to quickly access variables to avoid re-writing initiali

.. code-block:: python
from DeepDiagnostics.utils.configuration import Config
from deepdiagnostics.utils.configuration import Config
Config("path/to/your/config.yaml")
Expand Down Expand Up @@ -54,7 +56,7 @@ Configuration Description

:param model_path: Path to stored model. Required.

:param model_engine: Loading method to use. Choose from methods listed in :ref:`plots<plots>`
:param model_engine: Loading method to use. Choose from methods listed in :ref:`models`.

.. code-block:: yaml
Expand All @@ -66,11 +68,11 @@ Configuration Description

:param data_path: Path to stored data. Required.

:param data_engine: Loading method to use. Choose from methods listed in :ref:`plots<plots>`
:param data_engine: Loading method to use. Choose from methods listed in :ref:`data`.

:param simulator: String name of the simulator. Must be pre-registered .

:param prior: Prior distribution used in training. Used if "prior" is not included in the passed data. Choose from []
:param prior: Prior distribution used in training. Used if "prior" is not included in the passed data.

:param prior_kwargs: kwargs to use with the initialization of the prior

Expand Down
7 changes: 6 additions & 1 deletion docs/source/API/data.rst → docs/source/data.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.. _data:

Data
======

Expand All @@ -8,4 +10,7 @@ Data
:members:

.. autoclass:: data.PickleData
:members:
:members:

.. autoclass:: data.simulator.Simulator
:members:
7 changes: 3 additions & 4 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ Welcome to DeepDiagnostics's documentation!
configuration
plots
metrics
API/client
API/utils
API/data
API/models
client
data
models

Indices and tables
==================
Expand Down
2 changes: 2 additions & 0 deletions docs/source/API/models.rst → docs/source/models.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.. _models:

Models
========

Expand Down
6 changes: 3 additions & 3 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Quickstart
Notebook Example
-----------------

`An example notebook can be found here for an interactive walkthrough. <https://github.com/deepskies/DeepDiagnostics/blob/main/notebooks/example.ipynb>`_.
`An example notebook can be found here for an interactive walkthrough <https://github.com/deepskies/DeepDiagnostics/blob/main/notebooks/example.ipynb>`_.

Installation
--------------
Expand All @@ -27,7 +27,7 @@ Installation
Configuration
----

Description of the configuration file, including defaults, can be found in :ref:`configuration<configuration>`
Description of the configuration file, including defaults, can be found in :ref:`configuration`.

Pipeline
---------
Expand Down Expand Up @@ -75,7 +75,7 @@ All plots and metrics can be found in :ref:`plots<plots>` and :ref:`metrics<metr
Custom Simulations
---
-------------------

To use generative model diagnostics, a simulator has to be included.
This is done by `registering` your simulation with a name and a class associated.
Expand Down
22 changes: 17 additions & 5 deletions src/deepdiagnostics/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,45 @@

def parser():
parser = ArgumentParser()
parser.add_argument("--config", "-c", default=None)
parser.add_argument("--config", "-c", default=None, help=".yaml file with all arguments to run.")

# Model
parser.add_argument("--model_path", "-m", default=None)
parser.add_argument("--model_path", "-m", default=None, help="String path to a model. Must be compatible with your model_engine choice.")
parser.add_argument(
"--model_engine",
"-e",
default=Defaults["model"]["model_engine"],
choices=ModelModules.keys(),
help="Way to load your model. See each module's documentation page for requirements and specifications."
)

# Data
parser.add_argument("--data_path", "-d", default=None)
parser.add_argument("--data_path", "-d", default=None, help="String path to data. Must be compatible with data_engine choice.")
parser.add_argument(
"--data_engine",
"-g",
default=Defaults["data"]["data_engine"],
choices=DataModules.keys(),
help="Way to load your data. See each module's documentation page for requirements and specifications."
)
parser.add_argument("--simulator", "-s", default=None)
parser.add_argument(
"--simulator", "-s",
default=None,
help='String name of the simulator to use with generative metrics and plots. Must be pre-register with the `utils.register_simulator` method.')
# Common
parser.add_argument("--out_dir", default=Defaults["common"]["out_dir"])
parser.add_argument(
"--out_dir",
default=Defaults["common"]["out_dir"],
help="Where the results will be saved. Path need not exist, it will be created."
)

# List of metrics (cannot supply specific kwargs)
parser.add_argument(
"--metrics",
nargs="?",
default=list(Defaults["metrics"].keys()),
choices=Metrics.keys(),
help="List of metrics to run. To not run any, supply `--metrics `"
)

# List of plots
Expand All @@ -49,6 +59,8 @@ def parser():
nargs="?",
default=list(Defaults["plots"].keys()),
choices=Plots.keys(),
help="List of plots to run. To not run any, supply `--plots `"

)

args = parser.parse_args()
Expand Down
112 changes: 103 additions & 9 deletions src/deepdiagnostics/data/data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
from typing import Optional
from typing import Any, Optional, Sequence, Union
import numpy as np

from deepdiagnostics.utils.config import get_item
from deepdiagnostics.utils.register import load_simulator

class Data:
"""
Load stored data to use in diagnostics
Args:
path (str): path to the data file.
simulator_name (str): Name of the register simulator. If your simulator is not registered with utils.register_simulator, it will produce an error here.
simulator_kwargs (dict, optional): Any additional kwargs used set up your simulator. Defaults to None.
prior (str, optional): If the prior is not given in the data, use a numpy random distribution. Specified by name. Choose from: {
"normal"
"poisson"
"uniform"
"gamma"
"beta"
"binominal}. Defaults to None.
prior_kwargs (dict, optional): kwargs for the numpy prior. `View this page for a description <https://numpy.org/doc/stable/reference/random/generator.html#distributions>`_. Defaults to None.
simulation_dimensions (Optional[int], optional): 1 or 2. 1->output of the simulator has one dimensions, 2->output has two dimensions (is an image). Defaults to None.
"""
def __init__(
self,
path: str,
Expand All @@ -23,7 +40,13 @@ def __init__(
self.n_dims = self.get_theta_true().shape[1]
self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False)

def get_simulator_output_shape(self):
def get_simulator_output_shape(self) -> tuple[Sequence[int]]:
"""
Run a single sample of the simulator to verify the out-shape.
Returns:
tuple[Sequence[int]]: Output shape of a single sample of the simulator.
"""
context_shape = self.true_context().shape
sim_out = self.simulator(theta=self.get_theta_true()[0:1, :], n_samples=context_shape[-1])
return sim_out.shape
Expand All @@ -32,16 +55,47 @@ def _load(self, path: str):
raise NotImplementedError

def true_context(self):
"""
True data x values, if supplied by the data method.
"""
# From Data
raise NotImplementedError

def true_simulator_outcome(self):
def true_simulator_outcome(self) -> np.ndarray:
"""
Run the simulator on all true theta and true x values.
Returns:
np.ndarray: array of (n samples, simulator shape) showing output of the simulator on all true samples in data.
"""
return self.simulator(self.get_theta_true(), self.true_context())

def sample_prior(self, n_samples: int):
def sample_prior(self, n_samples: int) -> np.ndarray:
"""
Draw samples from the simulator
Args:
n_samples (int): Number of samples to draw
Returns:
np.ndarray:
"""
return self.prior_dist(size=(n_samples, self.n_dims))

def simulator_outcome(self, theta, condition_context=None, n_samples=None):
def simulator_outcome(self, theta:np.ndarray, condition_context:np.ndarray=None, n_samples:int=None):
"""_summary_
Args:
theta (np.ndarray): Theta value of shape (n_samples, theta_dimensions)
condition_context (np.ndarray, optional): If x values for theta are known, use them. Defaults to None.
n_samples (int, optional): If x values are not known for theta, draw them randomly. Defaults to None.
Raises:
ValueError: If either n samples or content samples is supplied.
Returns:
np.ndarray: Simulator output of shape (n samples, simulator_dimensions)
"""
if condition_context is None:
if n_samples is None:
raise ValueError(
Expand All @@ -51,16 +105,39 @@ def simulator_outcome(self, theta, condition_context=None, n_samples=None):
else:
return self.simulator.simulate(theta, condition_context)

def simulated_context(self, n_samples):
def simulated_context(self, n_samples:int) -> np.ndarray:
"""
Call the simulator's `generate_context` method.
Args:
n_samples (int): Number of samples to draw.
Returns:
np.ndarray: context (x values), as defined by the simulator.
"""
return self.simulator.generate_context(n_samples)

def get_theta_true(self):
def get_theta_true(self) -> Union[Any, float, int, np.ndarray]:
"""
Look for the true theta given by data. If supplied in the method, use that, other look in the configuration file.
If neither are supplied, return None.
Returns:
Any: Theta value selected by the search.
"""
if hasattr(self, "theta_true"):
return self.theta_true
else:
return get_item("data", "theta_true", raise_exception=True)

def get_sigma_true(self):
def get_sigma_true(self) -> Union[Any, float, int, np.ndarray]:
"""
Look for the true sigma of data. If supplied in the method, use that, other look in the configuration file.
If neither are supplied, return 1.
Returns:
Any: Sigma value selected by the search.
"""
if hasattr(self, "sigma_true"):
return self.sigma_true()
else:
Expand All @@ -72,7 +149,24 @@ def save(self, data, path: str):
def read_prior(self):
raise NotImplementedError

def load_prior(self, prior, prior_kwargs):
def load_prior(self, prior:str, prior_kwargs:dict[str, any]) -> callable:
"""
Load the prior.
Either try to get it from data (if it has been implemented for the type of data),
or use numpy to initialize a random distribution using the prior argument.
Args:
prior (str): Name of prior.
prior_kwargs (dict[str, any]): kwargs for initializing the prior.
Raises:
NotImplementedError: The selected prior is not included.
RuntimeError: The selected prior is missing arguments to initialize.
Returns:
callable: Prior that can be sampled from by calling it with prior(n_samples)
"""

if prior is None:
prior = get_item("data", "prior", raise_exception=False)
try:
Expand Down
Loading

0 comments on commit 3db2b86

Please sign in to comment.