diff --git a/docs/source/API/client.rst b/docs/source/API/client.rst deleted file mode 100644 index 6124f9e..0000000 --- a/docs/source/API/client.rst +++ /dev/null @@ -1,5 +0,0 @@ -Client -======== - -.. automodule:: client - :members: \ No newline at end of file diff --git a/docs/source/API/utils.rst b/docs/source/API/utils.rst deleted file mode 100644 index cc92a44..0000000 --- a/docs/source/API/utils.rst +++ /dev/null @@ -1,6 +0,0 @@ -Utils -======= - -.. autoclass:: utils.config.Config - :members: - diff --git a/docs/source/client.rst b/docs/source/client.rst new file mode 100644 index 0000000..a7cbf56 --- /dev/null +++ b/docs/source/client.rst @@ -0,0 +1,35 @@ +Client +======== + +.. note:: + When running the client, you can supply **either** the configuration yaml file, or the CLI arguments. + You do not need to supply both. + +Use the command `diagnose -h` to view all usage of the CLI helper at any time. +Specific argument descriptions and explanations can be found on the :ref:`configuration` page. + +.. code-block:: bash + + usage: diagnose [-h] [--config CONFIG] [--model_path MODEL_PATH] [--model_engine {SBIModel}] [--data_path DATA_PATH] [--data_engine {H5Data,PickleData}] + [--simulator SIMULATOR] [--out_dir OUT_DIR] [--metrics [{CoverageFraction,AllSBC,LC2ST}]] + [--plots [{CDFRanks,CoverageFraction,Ranks,TARP,LC2ST,PPC}]] + + options: + -h, --help show this help message and exit + --config CONFIG, -c CONFIG + .yaml file with all arguments to run. + --model_path MODEL_PATH, -m MODEL_PATH + String path to a model. Must be compatible with your model_engine choice. + --model_engine {SBIModel}, -e {SBIModel} + Way to load your model. See each module's documentation page for requirements and specifications. + --data_path DATA_PATH, -d DATA_PATH + String path to data. Must be compatible with data_engine choice. + --data_engine {H5Data,PickleData}, -g {H5Data,PickleData} + Way to load your data. See each module's documentation page for requirements and specifications. + --simulator SIMULATOR, -s SIMULATOR + String name of the simulator to use with generative metrics and plots. Must be pre-register with the `utils.register_simulator` method. + --out_dir OUT_DIR Where the results will be saved. Path need not exist, it will be created. + --metrics [{CoverageFraction,AllSBC,LC2ST}] + List of metrics to run. To not run any, supply `--metrics ` + --plots [{CDFRanks,CoverageFraction,Ranks,TARP,LC2ST,PPC}] + List of plots to run. To not run any, supply `--plots ` \ No newline at end of file diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index beb8c13..cb9bae0 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -1,3 +1,5 @@ +.. _configuration: + Configuration =============== @@ -23,7 +25,7 @@ but it can be specified to quickly access variables to avoid re-writing initiali .. code-block:: python - from DeepDiagnostics.utils.configuration import Config + from deepdiagnostics.utils.configuration import Config Config("path/to/your/config.yaml") @@ -54,7 +56,7 @@ Configuration Description :param model_path: Path to stored model. Required. - :param model_engine: Loading method to use. Choose from methods listed in :ref:`plots` + :param model_engine: Loading method to use. Choose from methods listed in :ref:`models`. .. code-block:: yaml @@ -66,11 +68,11 @@ Configuration Description :param data_path: Path to stored data. Required. - :param data_engine: Loading method to use. Choose from methods listed in :ref:`plots` + :param data_engine: Loading method to use. Choose from methods listed in :ref:`data`. :param simulator: String name of the simulator. Must be pre-registered . - :param prior: Prior distribution used in training. Used if "prior" is not included in the passed data. Choose from [] + :param prior: Prior distribution used in training. Used if "prior" is not included in the passed data. :param prior_kwargs: kwargs to use with the initialization of the prior diff --git a/docs/source/API/data.rst b/docs/source/data.rst similarity index 62% rename from docs/source/API/data.rst rename to docs/source/data.rst index 9e66449..c13d14e 100644 --- a/docs/source/API/data.rst +++ b/docs/source/data.rst @@ -1,3 +1,5 @@ +.. _data: + Data ====== @@ -8,4 +10,7 @@ Data :members: .. autoclass:: data.PickleData - :members: \ No newline at end of file + :members: + +.. autoclass:: data.simulator.Simulator + :members: \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 3ae2039..138025d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,10 +14,9 @@ Welcome to DeepDiagnostics's documentation! configuration plots metrics - API/client - API/utils - API/data - API/models + client + data + models Indices and tables ================== diff --git a/docs/source/API/models.rst b/docs/source/models.rst similarity index 89% rename from docs/source/API/models.rst rename to docs/source/models.rst index b732091..e4a124d 100644 --- a/docs/source/API/models.rst +++ b/docs/source/models.rst @@ -1,3 +1,5 @@ +.. _models: + Models ======== diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 8220da3..31f98c9 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -4,7 +4,7 @@ Quickstart Notebook Example ----------------- -`An example notebook can be found here for an interactive walkthrough. `_. +`An example notebook can be found here for an interactive walkthrough `_. Installation -------------- @@ -27,7 +27,7 @@ Installation Configuration ---- -Description of the configuration file, including defaults, can be found in :ref:`configuration` +Description of the configuration file, including defaults, can be found in :ref:`configuration`. Pipeline --------- @@ -75,7 +75,7 @@ All plots and metrics can be found in :ref:`plots` and :ref:`metrics`_. Defaults to None. + simulation_dimensions (Optional[int], optional): 1 or 2. 1->output of the simulator has one dimensions, 2->output has two dimensions (is an image). Defaults to None. + """ def __init__( self, path: str, @@ -23,7 +40,13 @@ def __init__( self.n_dims = self.get_theta_true().shape[1] self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False) - def get_simulator_output_shape(self): + def get_simulator_output_shape(self) -> tuple[Sequence[int]]: + """ + Run a single sample of the simulator to verify the out-shape. + + Returns: + tuple[Sequence[int]]: Output shape of a single sample of the simulator. + """ context_shape = self.true_context().shape sim_out = self.simulator(theta=self.get_theta_true()[0:1, :], n_samples=context_shape[-1]) return sim_out.shape @@ -32,16 +55,47 @@ def _load(self, path: str): raise NotImplementedError def true_context(self): + """ + True data x values, if supplied by the data method. + """ # From Data raise NotImplementedError - def true_simulator_outcome(self): + def true_simulator_outcome(self) -> np.ndarray: + """ + Run the simulator on all true theta and true x values. + + Returns: + np.ndarray: array of (n samples, simulator shape) showing output of the simulator on all true samples in data. + """ return self.simulator(self.get_theta_true(), self.true_context()) - def sample_prior(self, n_samples: int): + def sample_prior(self, n_samples: int) -> np.ndarray: + """ + Draw samples from the simulator + + Args: + n_samples (int): Number of samples to draw + + Returns: + np.ndarray: + """ return self.prior_dist(size=(n_samples, self.n_dims)) - def simulator_outcome(self, theta, condition_context=None, n_samples=None): + def simulator_outcome(self, theta:np.ndarray, condition_context:np.ndarray=None, n_samples:int=None): + """_summary_ + + Args: + theta (np.ndarray): Theta value of shape (n_samples, theta_dimensions) + condition_context (np.ndarray, optional): If x values for theta are known, use them. Defaults to None. + n_samples (int, optional): If x values are not known for theta, draw them randomly. Defaults to None. + + Raises: + ValueError: If either n samples or content samples is supplied. + + Returns: + np.ndarray: Simulator output of shape (n samples, simulator_dimensions) + """ if condition_context is None: if n_samples is None: raise ValueError( @@ -51,16 +105,39 @@ def simulator_outcome(self, theta, condition_context=None, n_samples=None): else: return self.simulator.simulate(theta, condition_context) - def simulated_context(self, n_samples): + def simulated_context(self, n_samples:int) -> np.ndarray: + """ + Call the simulator's `generate_context` method. + + Args: + n_samples (int): Number of samples to draw. + + Returns: + np.ndarray: context (x values), as defined by the simulator. + """ return self.simulator.generate_context(n_samples) - def get_theta_true(self): + def get_theta_true(self) -> Union[Any, float, int, np.ndarray]: + """ + Look for the true theta given by data. If supplied in the method, use that, other look in the configuration file. + If neither are supplied, return None. + + Returns: + Any: Theta value selected by the search. + """ if hasattr(self, "theta_true"): return self.theta_true else: return get_item("data", "theta_true", raise_exception=True) - def get_sigma_true(self): + def get_sigma_true(self) -> Union[Any, float, int, np.ndarray]: + """ + Look for the true sigma of data. If supplied in the method, use that, other look in the configuration file. + If neither are supplied, return 1. + + Returns: + Any: Sigma value selected by the search. + """ if hasattr(self, "sigma_true"): return self.sigma_true() else: @@ -72,7 +149,24 @@ def save(self, data, path: str): def read_prior(self): raise NotImplementedError - def load_prior(self, prior, prior_kwargs): + def load_prior(self, prior:str, prior_kwargs:dict[str, any]) -> callable: + """ + Load the prior. + Either try to get it from data (if it has been implemented for the type of data), + or use numpy to initialize a random distribution using the prior argument. + + Args: + prior (str): Name of prior. + prior_kwargs (dict[str, any]): kwargs for initializing the prior. + + Raises: + NotImplementedError: The selected prior is not included. + RuntimeError: The selected prior is missing arguments to initialize. + + Returns: + callable: Prior that can be sampled from by calling it with prior(n_samples) + """ + if prior is None: prior = get_item("data", "prior", raise_exception=False) try: diff --git a/src/deepdiagnostics/data/h5_data.py b/src/deepdiagnostics/data/h5_data.py index b029144..d975a30 100644 --- a/src/deepdiagnostics/data/h5_data.py +++ b/src/deepdiagnostics/data/h5_data.py @@ -8,13 +8,25 @@ class H5Data(Data): + """ + Load data that has been saved in a h5 format. + + .. attribute:: Data Parameters + + :xs: [REQUIRED] The context, the x values. The data that was used to train a model on what conditions produce what posterior. + :thetas: [REQUIRED] The theta, the parameters of the external model. The data used to train the model's posterior. + :prior: Distribution used to initialize the posterior before training. + :sigma: True standard deviation of the actual thetas, if known. + + """ + def __init__(self, - path: str, - simulator: Callable, - simulator_kwargs: dict = None, - prior: str = None, - prior_kwargs: dict = None, - simulation_dimensions:Optional[int] = None, + path, + simulator, + simulator_kwargs = None, + prior=None, + prior_kwargs = None, + simulation_dimensions = None, ): super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions) @@ -38,17 +50,50 @@ def save(self, data: dict[str, Any], path: str): # Todo typing for data dict file.create_dataset(key, data=value) def true_context(self): - # From Data - return self.data["xs"] # TODO change name + """ + Try to get the `xs` field of the loaded data. + + Raises: + NotImplementedError: The data does not have a `xs` field. + """ + try: + return self.data["xs"] + except KeyError: + raise NotImplementedError("Cannot find `xs` in data. Please supply it.") def prior(self): - # From Data - raise NotImplementedError + """ + If the data has a supplied prior, return it. If not, the data module will default back to picking a prior from a random distribution. + + Raises: + NotImplementedError: The data does not have a `prior` field. + """ + try: + return self.data['prior'] + except KeyError: + raise NotImplementedError("Data does not have a `prior` field.") def get_theta_true(self): - return self.data["thetas"] + """ Get stored theta used to train the model. + + Returns: + theta array + + Raise: + NotImplementedError: Data does not have thetas. + """ + try: + return self.data["thetas"] + except KeyError: + raise NotImplementedError("Data does not have a `thetas` field.") def get_sigma_true(self): + """ + Try to get the true standard deviation of the data. If it is not supplied, return 1. + + Returns: + Any: sigma. + """ try: return super().get_sigma_true() except (AssertionError, KeyError): diff --git a/src/deepdiagnostics/data/pickle_data.py b/src/deepdiagnostics/data/pickle_data.py index 8c9ac4b..0659c02 100644 --- a/src/deepdiagnostics/data/pickle_data.py +++ b/src/deepdiagnostics/data/pickle_data.py @@ -1,26 +1,36 @@ import pickle -from typing import Any, Callable, Optional +from typing import Any from deepdiagnostics.data.data import Data class PickleData(Data): + """ + Load data that is saved as a .pkl file. + """ def __init__(self, - path: str, - simulator: Callable, - simulator_kwargs: dict = None, - prior: str = None, - prior_kwargs: dict = None, - simulation_dimensions:Optional[int] = None, + path, + simulator, + simulator_kwargs = None, + prior = None, + prior_kwargs = None, + simulation_dimensions = None, ): super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions) - def _load(self, path: str): + def _load(self, path: str) -> Any: assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" with open(path, "rb") as file: data = pickle.load(file) return data - def save(self, data: Any, path: str): + def save(self, data: Any, path: str) -> None: + """ + Save data in the form of a .pkl file. + + Args: + data (Any): Data that can be encoded into a pkl. + path (str): Out file path for the data. Must have a .pkl extension. + """ assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" with open(path, "wb") as file: pickle.dump(data, file) diff --git a/src/deepdiagnostics/data/simulator.py b/src/deepdiagnostics/data/simulator.py index 7274086..6b77d77 100644 --- a/src/deepdiagnostics/data/simulator.py +++ b/src/deepdiagnostics/data/simulator.py @@ -14,9 +14,27 @@ def __init__(self) -> None: def generate_context(self, n_samples: int) -> np.ndarray: """ [ABSTRACT, MUST BE FILLED] + Specify how the conditioning context is generated. Can come from data, or from a generic distribution. + Example: + + .. code-block:: python + + # Generate from a random distribution + class MySim(Simulator): + def generate_context(self, n_samples: int) -> np.ndarray: + return np.random.uniform(0, 1) + + # Draw from a sample + class MySim(Simulator): + def __init__(self): + self.data_source = ..... + + def generate_context(self, n_samples: int) -> np.ndarray: + return self.data_source.sample(n_samples) + Args: n_samples (int): Number of samples of context to pull @@ -29,8 +47,21 @@ def generate_context(self, n_samples: int) -> np.ndarray: def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray: """ [ABSTRACT, MUST BE FILLED] + Specify a simulation S such that y_{theta} = S(context_samples|theta) + Example: + .. code-block:: python + + # Generate from a random distribution + class MySim(Simulator): + def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray: + simulation_results = np.zeros(theta.shape[0], 1) + for index, context in enumerate(context_samples): + simulation_results[index] = theta[index][0]*context + theta[index][1]*context + + return simulation_results + Args: theta (np.ndarray): Parameters of the simulation model context_samples (np.ndarray): Samples to use with the theta-primed simulation model diff --git a/src/deepdiagnostics/models/model.py b/src/deepdiagnostics/models/model.py index 254d584..46e2c81 100644 --- a/src/deepdiagnostics/models/model.py +++ b/src/deepdiagnostics/models/model.py @@ -1,4 +1,10 @@ class Model: + """ + Load a pre-trained model for analysis. + + Args: + model_path (str): relative path to a model. + """ def __init__(self, model_path: str) -> None: self.model = self._load(model_path) diff --git a/src/deepdiagnostics/models/sbi_model.py b/src/deepdiagnostics/models/sbi_model.py index ea6c603..f062698 100644 --- a/src/deepdiagnostics/models/sbi_model.py +++ b/src/deepdiagnostics/models/sbi_model.py @@ -5,7 +5,14 @@ class SBIModel(Model): - def __init__(self, model_path: str): + """ + Load a trained model that was generated with Mackelab SBI :cite:p:`centero2020sbi`. + `Read more about saving and loading requirements here `_. + + Args: + model_path (str): relative path to a model - must be a .pkl file. + """ + def __init__(self, model_path): super().__init__(model_path) def _load(self, path: str) -> None: @@ -16,14 +23,34 @@ def _load(self, path: str) -> None: posterior = pickle.load(file) self.posterior = posterior - def sample_posterior(self, n_samples: int, y_true): # TODO typing + def sample_posterior(self, n_samples: int, x_true): + """ + Sample the posterior + + Args: + n_samples (int): Number of samples to draw + x_true (np.ndarray): Context samples. (must be dims=(n_samples, M)) + + Returns: + np.ndarray: Posterior samples + """ return self.posterior.sample( - (n_samples,), x=y_true, show_progress_bars=False + (n_samples,), x=x_true, show_progress_bars=False ).cpu() # TODO Unbind from cpu - def predict_posterior(self, data): - posterior_samples = self.sample_posterior(data.y_true) + def predict_posterior(self, data, context_samples): + """ + Sample the posterior and then + + Args: + data (deepdiagnostics.data.Data): Data module with the loaded simulation + context_samples (np.ndarray): X values to test the posterior over. + + Returns: + np.ndarray: Simulator output + """ + posterior_samples = self.sample_posterior(context_samples) posterior_predictive_samples = data.simulator( - data.theta_true(), posterior_samples + posterior_samples, context_samples ) return posterior_predictive_samples