diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..d8bad0c --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Lint it + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 + + - name: Lint code + run: flake8 . + + # Add other steps as needed, such as running additional linters or formatters \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 99bc88f..c37b4bd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -8,6 +8,15 @@ jobs: runs-on: ubuntu-latest steps: + - name: Cache Poetry dependencies + uses: actions/cache@v2 + with: + path: | + ~/.cache + ~/.local/share/virtualenvs + key: ${{ runner.os }}-poetry-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-poetry- - uses: actions/checkout@v2 - name: Set up Python 3.10 uses: actions/setup-python@v2 @@ -27,6 +36,11 @@ jobs: shell: bash run: python -m poetry install + - name: Create Environment File + run: echo "PYTHONPATH=$(pwd):$(pwd)/src" >> ${{ runner.workspace }}/.env + - name: Test with pytest - run: | - python3 -m poetry run pytest --cov + run: python -m poetry run pytest --cov + env: + PYTHONPATH: ${{ env.PYTHONPATH }} + ENV_FILE: ${{ runner.workspace }}/.env diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/poetry.lock b/poetry.lock index 9fbb21d..0cfcdfa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -226,6 +226,52 @@ charset-normalizer = ["charset-normalizer"] html5lib = ["html5lib"] lxml = ["lxml"] +[[package]] +name = "black" +version = "24.1.1" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.8" +files = [ + {file = "black-24.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2588021038bd5ada078de606f2a804cadd0a3cc6a79cb3e9bb3a8bf581325a4c"}, + {file = "black-24.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a95915c98d6e32ca43809d46d932e2abc5f1f7d582ffbe65a5b4d1588af7445"}, + {file = "black-24.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fa6a0e965779c8f2afb286f9ef798df770ba2b6cee063c650b96adec22c056a"}, + {file = "black-24.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:5242ecd9e990aeb995b6d03dc3b2d112d4a78f2083e5a8e86d566340ae80fec4"}, + {file = "black-24.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fc1ec9aa6f4d98d022101e015261c056ddebe3da6a8ccfc2c792cbe0349d48b7"}, + {file = "black-24.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0269dfdea12442022e88043d2910429bed717b2d04523867a85dacce535916b8"}, + {file = "black-24.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3d64db762eae4a5ce04b6e3dd745dcca0fb9560eb931a5be97472e38652a161"}, + {file = "black-24.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:5d7b06ea8816cbd4becfe5f70accae953c53c0e53aa98730ceccb0395520ee5d"}, + {file = "black-24.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e2c8dfa14677f90d976f68e0c923947ae68fa3961d61ee30976c388adc0b02c8"}, + {file = "black-24.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a21725862d0e855ae05da1dd25e3825ed712eaaccef6b03017fe0853a01aa45e"}, + {file = "black-24.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07204d078e25327aad9ed2c64790d681238686bce254c910de640c7cc4fc3aa6"}, + {file = "black-24.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:a83fe522d9698d8f9a101b860b1ee154c1d25f8a82ceb807d319f085b2627c5b"}, + {file = "black-24.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08b34e85170d368c37ca7bf81cf67ac863c9d1963b2c1780c39102187ec8dd62"}, + {file = "black-24.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7258c27115c1e3b5de9ac6c4f9957e3ee2c02c0b39222a24dc7aa03ba0e986f5"}, + {file = "black-24.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40657e1b78212d582a0edecafef133cf1dd02e6677f539b669db4746150d38f6"}, + {file = "black-24.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e298d588744efda02379521a19639ebcd314fba7a49be22136204d7ed1782717"}, + {file = "black-24.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:34afe9da5056aa123b8bfda1664bfe6fb4e9c6f311d8e4a6eb089da9a9173bf9"}, + {file = "black-24.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:854c06fb86fd854140f37fb24dbf10621f5dab9e3b0c29a690ba595e3d543024"}, + {file = "black-24.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3897ae5a21ca132efa219c029cce5e6bfc9c3d34ed7e892113d199c0b1b444a2"}, + {file = "black-24.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:ecba2a15dfb2d97105be74bbfe5128bc5e9fa8477d8c46766505c1dda5883aac"}, + {file = "black-24.1.1-py3-none-any.whl", hash = "sha256:5cdc2e2195212208fbcae579b931407c1fa9997584f0a415421748aeafff1168"}, + {file = "black-24.1.1.tar.gz", hash = "sha256:48b5760dcbfe5cf97fd4fba23946681f3a81514c6ab8a45b50da67ac8fbc6c7b"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "bleach" version = "6.1.0" @@ -425,6 +471,20 @@ files = [ {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"}, ] +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "colorama" version = "0.4.6" @@ -723,6 +783,22 @@ files = [ docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] +[[package]] +name = "flake8" +version = "7.0.0" +description = "the modular source code checker: pep8 pyflakes and co" +optional = false +python-versions = ">=3.8.1" +files = [ + {file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"}, + {file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"}, +] + +[package.dependencies] +mccabe = ">=0.7.0,<0.8.0" +pycodestyle = ">=2.11.0,<2.12.0" +pyflakes = ">=3.2.0,<3.3.0" + [[package]] name = "fonttools" version = "4.47.2" @@ -1786,6 +1862,17 @@ files = [ [package.dependencies] traitlets = "*" +[[package]] +name = "mccabe" +version = "0.7.0" +description = "McCabe checker, plugin for flake8" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, +] + [[package]] name = "mistune" version = "3.0.2" @@ -1814,6 +1901,17 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4)"] tests = ["pytest (>=4.6)"] +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "nbclient" version = "0.9.0" @@ -2335,6 +2433,17 @@ files = [ qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = ["docopt", "pytest (<6.0.0)"] +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -2608,6 +2717,17 @@ files = [ [package.dependencies] pyasn1 = ">=0.4.6,<0.6.0" +[[package]] +name = "pycodestyle" +version = "2.11.1" +description = "Python style guide checker" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"}, + {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, +] + [[package]] name = "pycparser" version = "2.21" @@ -2619,6 +2739,17 @@ files = [ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +[[package]] +name = "pyflakes" +version = "3.2.0" +description = "passive checker of Python programs" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"}, + {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, +] + [[package]] name = "pygments" version = "2.17.2" @@ -3810,4 +3941,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "0bb93711856d50808e2e809e075eb3a602027a13c1cd4e7a21020d42c17b6d31" +content-hash = "d48599076a6454177bc9bee98abb863d5cad0559004d1c407de83b7e1200a590" diff --git a/pyproject.toml b/pyproject.toml index aed5895..0ca1395 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,8 @@ license = "MIT" python = "^3.9" jupyter = "^1.0.0" sbi = "^0.22.0" +flake8 = "^7.0.0" +black = "^24.1.1" [tool.poetry.dev-dependencies] pytest = "^7.3.2" diff --git a/src/scripts/evaluate.py b/src/scripts/evaluate.py index 6c8a103..6f51ab7 100644 --- a/src/scripts/evaluate.py +++ b/src/scripts/evaluate.py @@ -1,6 +1,9 @@ """ -Simple stub functions to use in evaluating inference from a previously trained inference model. +Diagnostic tools for evaluating the quality of the posterior +from a previously trained inference model. +Includes utilities for posterior diagnostics as well as some +inference functions. """ import argparse @@ -14,6 +17,7 @@ import matplotlib import matplotlib.pyplot as plt from cycler import cycler + # remove top and right axis from plots matplotlib.rcParams["axes.spines.right"] = False matplotlib.rcParams["axes.spines.top"] = False @@ -41,7 +45,7 @@ def load_model_pkl(self, path, model_name): :return: Loaded model object that can be used with the predict function """ print(path) - with open(path + model_name + ".pkl", 'rb') as file: + with open(path + model_name + ".pkl", "rb") as file: posterior = pickle.load(file) return posterior @@ -56,36 +60,41 @@ def predict(input, model): :return: Prediction """ return 0 - class Display: - def mackelab_corner_plot(self, - posterior_samples, - labels_list=None, - limit_list=None, - truth_list=None): + def mackelab_corner_plot( + self, + posterior_samples, + labels_list=None, + limit_list=None, + truth_list=None + ): """ - Uses existing pairplot from mackelab analysis to produce a flexible - corner plot. + Uses existing pairplot from mackelab analysis + to produce a flexible corner plot. - :param posterior_samples: Samples drawn from the posterior, conditional on data + :param posterior_samples: Samples drawn from the posterior, + conditional on data :param labels_list: A list of the labels for the parameters :param limit_list: A list of limits for each parameter plot :return: Loaded model object that can be used with the predict function """ # plot posterior samples - #if labels_list: - #if limit_list: + # if labels_list: + # if limit_list: fig, axes = pairplot( posterior_samples, labels_list=labels_list, limits=limit_list, - #[[0,10],[-10,10],[0,10]], + # [[0,10],[-10,10],[0,10]], truths=truth_list, - figsize=(5, 5) + figsize=(5, 5), ) - axes[0, 1].plot([truth_list[1]], [truth_list[0]], marker="o", color="red") + axes[0, 1].plot([truth_list[1]], [truth_list[0]], + marker="o", + color="red") + def improved_corner_plot(self, posterior): """ Improved corner plot @@ -93,7 +102,12 @@ def improved_corner_plot(self, posterior): class Diagnose: - def posterior_predictive(self, theta_true, x_true, simulator, posterior_samples): + def posterior_predictive(self, + theta_true, + x_true, + simulator, + posterior_samples, + true_sigma): # not sure how or where to define the simulator # could require that people input posterior predictive samples, # already drawn from the simulator @@ -104,29 +118,35 @@ def posterior_predictive(self, theta_true, x_true, simulator, posterior_samples) plt.clf() xs_sim = np.linspace(0, 100, 101) ys_sim = np.array(posterior_predictive_samples) - plt.fill_between(xs_sim, - np.mean(ys_sim, axis = 0) - 1 * np.std(ys_sim, axis = 0), - np.mean(ys_sim, axis = 0) + 1 * np.std(ys_sim, axis = 0), - color = '#FF495C', label = 'posterior predictive check with noise') - plt.plot(xs_sim, np.mean(ys_sim, axis = 0) + true_sigma, - color = '#25283D', label = 'true input error') - plt.plot(xs_sim, np.mean(ys_sim, axis = 0) - true_sigma, - color = '#25283D') - plt.scatter(xs_sim, - np.array(y_true), - color = 'black')#'#EFD9CE') - + plt.fill_between( + xs_sim, + np.mean(ys_sim, axis=0) - 1 * np.std(ys_sim, axis=0), + np.mean(ys_sim, axis=0) + 1 * np.std(ys_sim, axis=0), + color="#FF495C", + label="posterior predictive check with noise", + ) + plt.plot( + xs_sim, + np.mean(ys_sim, axis=0) + true_sigma, + color="#25283D", + label="true input error", + ) + plt.plot(xs_sim, np.mean(ys_sim, axis=0) - true_sigma, color="#25283D") + plt.scatter(xs_sim, np.array(y_true), color="black") plt.legend() plt.show() return ys_sim - - def generate_sbc_samples(self, - prior, - posterior, - simulator, - num_sbc_runs=1_000, - num_posterior_samples=1_000): - # generate ground truth parameters and corresponding simulated observations for SBC. + + def generate_sbc_samples( + self, + prior, + posterior, + simulator, + num_sbc_runs=1_000, + num_posterior_samples=1_000, + ): + # generate ground truth parameters + # and corresponding simulated observations for SBC. thetas = prior.sample((num_sbc_runs,)) ys = simulator(thetas) # run SBC: for each inference we draw 1000 posterior samples. @@ -134,39 +154,63 @@ def generate_sbc_samples(self, thetas, ys, posterior, num_posterior_samples=num_posterior_samples ) return thetas, ys, ranks, dap_samples - + def sbc_statistics(self, ranks, thetas, dap_samples, num_posterior_samples): - ''' - The ks pvalues are vanishingly small here, so we can reject the null hypothesis (of the marginal rank distributions being equivalent to an uniform distribution). The inference clearly went wrong. - - In terms of the c2st_ranks diagnostic; this is a nonparametric two sample test from training on and testing on the rank versus uniform distributions and distinguishing between them. If values are close to 0.5, it is hard to distinguish. - - The data-averaged posterior value compares to the prior; if these values are close to 0.5, dap is like the prior distribution. - ''' + """ + The ks pvalues are vanishingly small here, + so we can reject the null hypothesis + (of the marginal rank distributions being equivalent to + an uniform distribution). The inference clearly went wrong. + + In terms of the c2st_ranks diagnostic; + this is a nonparametric two sample test from training on + and testing on the rank versus uniform distributions + and distinguishing between them. If values are close to 0.5, + it is hard to distinguish. + + The data-averaged posterior value compares to the prior; + if these values are close to 0.5, dap is like the prior distribution. + """ check_stats = check_sbc( - ranks, thetas, dap_samples, num_posterior_samples=num_posterior_samples + ranks, + thetas, + dap_samples, + num_posterior_samples=num_posterior_samples ) return check_stats - def plot_1d_ranks(self, - ranks, - num_posterior_samples, - labels_list, - colorlist, - plot=False, - save=True, - path='plots/'): - """ - If the rank plots are consistent with being uniform, the red bars should fall mostly within the grey area. The grey area is the 99% confidence interval for the uniform distribution, so if the rank histogram falls outside this for more than 1 in 100 bars, that means it is not consistent with an uniform distribution (which is what we want). - A central peaked rank plot could be indicative of a posterior that's too concentrated whereas one with wings is a posterior that is too dispersed. Conversely, if the distribution is shifted left or right that indicates that the parameters are biased. - - If it's choppy, it could be indicative of not doing enough sampling. A good rule of thumb is N / B ~ 20, where N is the number of samples and B is the number of bins. + def plot_1d_ranks( + self, + ranks, + num_posterior_samples, + labels_list, + colorlist, + plot=False, + save=True, + path="plots/", + ): + """ + If the rank plots are consistent with being uniform, + the color bars should fall mostly within the grey area. + The grey area is the 99% confidence interval + for the uniform distribution, so if the rank histogram falls + outside this for more than 1 in 100 bars, that means it is not + consistent with an uniform distribution (which is what we want). + + A central peaked rank plot could be indicative of a posterior that's + too concentrated whereas one with wings is a posterior that is + too dispersed. Conversely, if the distribution is shifted left or + right that indicates that the parameters are biased. + + If it's choppy, it could be indicative of not doing enough sampling. + A good rule of thumb is N / B ~ 20, where N is the number of samples + and B is the number of bins. """ - #help(sbc_rank_plot) + # help(sbc_rank_plot) if colorlist: _ = sbc_rank_plot( ranks=ranks, @@ -174,7 +218,7 @@ def plot_1d_ranks(self, plot_type="hist", num_bins=None, parameter_labels=labels_list, - colors=colorlist + colors=colorlist, ) else: _ = sbc_rank_plot( @@ -187,124 +231,152 @@ def plot_1d_ranks(self, if plot: plt.show() if save: - plt.savefig(path+'sbc_ranks.pdf') - - def plot_cdf_1d_ranks(self, - ranks, - num_posterior_samples, - labels_list, - colorlist, - plot=False, - save=True, - path='plots/'): + plt.savefig(path + "sbc_ranks.pdf") + + def plot_cdf_1d_ranks( + self, + ranks, + num_posterior_samples, + labels_list, + colorlist, + plot=False, + save=True, + path="plots/", + ): """ - This is a different way to visualize the same thing from the 1d rank plots. - Essentially, the grey is the 95% confidence interval for an uniform distribution. - The cdf for the posterior rank distributions (in color) should fall within this band. + This is a different way to visualize the same thing + from the 1d rank plots. + Essentially, the grey is the 95% confidence interval for + an uniform distribution. + The cdf for the posterior rank distributions (in color) should fall + within this band. """ help(sbc_rank_plot) if colorlist: - f, ax = sbc_rank_plot(ranks, - num_posterior_samples, - plot_type="cdf", - parameter_labels=labels_list, - colors = colorlist) + f, ax = sbc_rank_plot( + ranks, + num_posterior_samples, + plot_type="cdf", + parameter_labels=labels_list, + colors=colorlist, + ) else: - f, ax = sbc_rank_plot(ranks, - num_posterior_samples, - plot_type="cdf", - parameter_labels=labels_list) + f, ax = sbc_rank_plot( + ranks, + num_posterior_samples, + plot_type="cdf", + parameter_labels=labels_list, + ) if plot: plt.show() if save: - plt.savefig(path+'sbc_ranks_cdf.pdf') - - def calculate_coverage_fraction(self, - posterior, - thetas, - ys, - percentile_list, - samples_per_inference=1_000): + plt.savefig(path + "sbc_ranks_cdf.pdf") + + def calculate_coverage_fraction( + self, + posterior, + thetas, + ys, + percentile_list, + samples_per_inference=1_000 + ): """ posterior --> the trained posterior thetas --> true parameter values ys --> the "observed" data used for inference - + """ # this holds all posterior samples for each inference run - all_samples = np.empty((len(ys), samples_per_inference, np.shape(thetas)[1])) + all_samples = np.empty((len(ys), samples_per_inference, + np.shape(thetas)[1])) count_array = [] # make this for loop into a progress bar: - for i in tqdm(range(len(ys)), desc='Sampling from the posterior for each obs', unit='obs'): - #for i in range(len(ys)): - # sample from the trained posterior n_sample times for each observation - samples = posterior.sample(sample_shape=(samples_per_inference,), - x=ys[i], - show_progress_bars=False).cpu() - - ''' + for i in tqdm( + range(len(ys)), + desc="Sampling from the posterior for each obs", + unit="obs" + ): + # for i in range(len(ys)): + # sample from the trained posterior n_sample times + # for each observation + samples = posterior.sample( + sample_shape=(samples_per_inference,), x=ys[i], + show_progress_bars=False + ).cpu() + + """ # plot posterior samples fig, axes = pairplot( - samples, + samples, labels = ['m', 'b'], #limits = [[0,10],[-10,10],[0,10]], truths = truth_array[i], figsize=(5, 5) ) - axes[0, 1].plot([truth_array[i][1]], [truth_array[i][0]], marker="o", color="r") - ''' - + axes[0, 1].plot([truth_array[i][1]], [truth_array[i][0]], + marker="o", color="r") + """ + all_samples[i] = samples count_vector = [] # step through the percentile list for ind, cov in enumerate(percentile_list): - percentile_l = 50.0 - cov/2 - percentile_u = 50.0 + cov/2 + percentile_l = 50.0 - cov / 2 + percentile_u = 50.0 + cov / 2 # find the percentile for the posterior for this observation # this is n_params dimensional # the units are in parameter space - confidence_l = np.percentile(samples.cpu(), percentile_l, axis=0) - confidence_u = np.percentile(samples.cpu(), percentile_u, axis=0) - # this is asking if the true parameter value is contained between the + confidence_l = np.percentile(samples.cpu(), + percentile_l, + axis=0) + confidence_u = np.percentile(samples.cpu(), + percentile_u, + axis=0) + # this is asking if the true parameter value + # is contained between the # upper and lower confidence intervals # checks separately for each side of the 50th percentile - count = np.logical_and(confidence_u - thetas.T[:,i] > 0, thetas.T[:,i] - confidence_l > 0) + count = np.logical_and( + confidence_u - thetas.T[:, i] > 0, + thetas.T[:, i] - confidence_l > 0 + ) count_vector.append(count) # each time the above is > 0, adds a count count_array.append(count_vector) count_sum_array = np.sum(count_array, axis=0) frac_lens_within_vol = np.array(count_sum_array) - return all_samples, np.array(frac_lens_within_vol)/len(ys) - - - - def plot_coverage_fraction(self, - posterior, - thetas, - ys, - samples_per_inference, - labels_list, - colorlist, - n_percentile_steps=21, - plot=False, - save=True, - path='plots/'): - percentile_array = np.linspace(0,100,n_percentile_steps) - samples, frac_array = self.calculate_coverage_fraction(posterior, - np.array(thetas), - ys, - percentile_array, - samples_per_inference=samples_per_inference) - - - percentile_array_norm = np.array(percentile_array)/100 + return all_samples, np.array(frac_lens_within_vol) / len(ys) + + def plot_coverage_fraction( + self, + posterior, + thetas, + ys, + samples_per_inference, + labels_list, + colorlist, + n_percentile_steps=21, + plot=False, + save=True, + path="plots/", + ): + percentile_array = np.linspace(0, 100, n_percentile_steps) + samples, frac_array = self.calculate_coverage_fraction( + posterior, + np.array(thetas), + ys, + percentile_array, + samples_per_inference=samples_per_inference, + ) + + percentile_array_norm = np.array(percentile_array) / 100 # Create a cycler with hexcode colors and linestyles if colorlist: color_cycler = cycler(color=colorlist) else: - color_cycler = cycler(color='viridis') - linestyle_cycler = cycler(linestyle=['-', '-.']) + color_cycler = cycler(color="viridis") + linestyle_cycler = cycler(linestyle=["-", "-."]) # Plotting fig, ax = plt.subplots(1, 1, figsize=(6, 6)) @@ -314,93 +386,121 @@ def plot_coverage_fraction(self, linestyle_cycle = iter(linestyle_cycler) # Iterate over the second dimension of frac_array for i in range(frac_array.shape[1]): - color_style = next(color_cycle)['color'] - linestyle_style = next(linestyle_cycle)['linestyle'] - ax.plot(percentile_array_norm, - frac_array[:, i], - alpha=1.0, - lw=3, - linestyle=linestyle_style, - color=color_style, - label=labels_list[i]) - - ax.plot([0, 0.5, 1], [0, 0.5, 1], 'k--', lw=3, zorder=1000, label='Reference Line') + color_style = next(color_cycle)["color"] + linestyle_style = next(linestyle_cycle)["linestyle"] + ax.plot( + percentile_array_norm, + frac_array[:, i], + alpha=1.0, + lw=3, + linestyle=linestyle_style, + color=color_style, + label=labels_list[i], + ) + + ax.plot( + [0, 0.5, 1], [0, 0.5, 1], + "k--", lw=3, zorder=1000, label="Reference Line" + ) ax.set_xlim([-0.05, 1.05]) ax.set_ylim([-0.05, 1.05]) - ax.text(0.03, 0.93, 'Underconfident', horizontalalignment='left') - ax.text(0.3, 0.05, 'Overconfident', horizontalalignment='left') - ax.legend(loc='lower right') - ax.set_xlabel('Confidence Interval of the Posterior Volume') - ax.set_ylabel('Fraction of Lenses within Posterior Volume') - ax.set_title('NPE') + ax.text(0.03, 0.93, "Underconfident", horizontalalignment="left") + ax.text(0.3, 0.05, "Overconfident", horizontalalignment="left") + ax.legend(loc="lower right") + ax.set_xlabel("Confidence Interval of the Posterior Volume") + ax.set_ylabel("Fraction of Lenses within Posterior Volume") + ax.set_title("NPE") plt.tight_layout() if plot: plt.show() if save: - plt.savefig(path+'coverage.pdf') - - - def run_all_sbc(self, - prior, - posterior, - simulator, - labels_list, - colorlist, - num_sbc_runs=1_000, - num_posterior_samples=1_000, - samples_per_inference=1_000, - plot=True, - save=False, - path='../plots/'): + plt.savefig(path + "coverage.pdf") + + def run_all_sbc( + self, + prior, + posterior, + simulator, + labels_list, + colorlist, + num_sbc_runs=1_000, + num_posterior_samples=1_000, + samples_per_inference=1_000, + plot=True, + save=False, + path="plots/", + ): """ Runs and displays mackelab's SBC (simulation-based calibration) - Simulation-based calibration is a set of tools built into Mackelab's sbi interface. It provides a way to compare the inferred posterior distribution to the true parameter values. It performs multiple instances of drawing parameter values from the prior, running these through the simulator, and comparing these values to those obtained from the run of inference. Importantly, this will not diagnose what's going on for one draw from the posterior (ie at one data point). Instead, it's meant to give an overall sense of the health of the posterior learned from SBI. - - This technique is based on rank plots. Rank plots are produced from comparing each posterior parameter draw (from the prior) to the distribution of parameter values in the posterior. There should be a 1:1 ranking, aka these rank plots should be similar in shape to a uniform distribution. + Simulation-based calibration is a set of tools built into + Mackelab's sbi interface. It provides a way to compare the + inferred posterior distribution to the true parameter values. + It performs multiple instances of drawing parameter values from + the prior, running these through the simulator, and comparing + these values to those obtained from the run of inference. + Importantly, this will not diagnose what's going on for one draw + from the posterior (ie at one data point). Instead, it's meant to + give an overall sense of the health of the posterior learned from SBI. + + This technique is based on rank plots. + Rank plots are produced from comparing each posterior parameter draw + (from the prior) to the distribution of parameter values in the + posterior. There should be a 1:1 ranking, aka these rank plots should + be similar in shape to a uniform distribution. """ - thetas, ys, ranks, dap_samples = self.generate_sbc_samples(prior, - posterior, - simulator, - num_sbc_runs, - num_posterior_samples) - - stats = self.sbc_statistics(ranks, - thetas, - dap_samples, + thetas, ys, ranks, dap_samples = self.generate_sbc_samples( + prior, posterior, simulator, num_sbc_runs, num_posterior_samples + ) + + stats = self.sbc_statistics(ranks, thetas, dap_samples, num_posterior_samples) print(stats) - self.plot_1d_ranks(ranks, - num_posterior_samples, - labels_list, - colorlist, - plot=plot, - save=save, - path=path) - - self.plot_cdf_1d_ranks(ranks, - num_posterior_samples, - labels_list, - colorlist, - plot=plot, - save=save, - path=path) - - self.plot_coverage_fraction(posterior, - thetas, - ys, - samples_per_inference, - labels_list, - colorlist, - n_percentile_steps=21, - plot=plot, - save=save, - path=path) - - def parameter_1_to_1_plots(samples,): - ''' - We've already saved samples, let's compare the inferred (and associated error bar) parameters from each of the data points we used for the SBC analysis. - ''' + self.plot_1d_ranks( + ranks, + num_posterior_samples, + labels_list, + colorlist, + plot=plot, + save=save, + path=path, + ) + + self.plot_cdf_1d_ranks( + ranks, + num_posterior_samples, + labels_list, + colorlist, + plot=plot, + save=save, + path=path, + ) + + self.plot_coverage_fraction( + posterior, + thetas, + ys, + samples_per_inference, + labels_list, + colorlist, + n_percentile_steps=21, + plot=plot, + save=save, + path=path, + ) + + def parameter_1_to_1_plots( + samples, + thetas, + color_list, + m_color, + b_color, + ): + """ + We've already saved samples, let's compare the inferred + (and associated error bar) parameters from each of the data points we + used for the SBC analysis. + """ print(np.shape(samples), np.shape(thetas)) percentile_16_m = [] @@ -410,64 +510,71 @@ def parameter_1_to_1_plots(samples,): percentile_50_b = [] percentile_84_b = [] for i in range(len(samples[0])): - #print(np.shape(samples[i])) - #STOP - percentile_16_m.append(np.percentile(samples[i,0], 16)) - percentile_50_m.append(np.percentile(samples[i,0], 50)) - percentile_84_m.append(np.percentile(samples[i,0], 84)) - percentile_16_b.append(np.percentile(samples[i,1], 16)) - percentile_50_b.append(np.percentile(samples[i,1], 50)) - percentile_84_b.append(np.percentile(samples[i,1], 84)) - yerr_minus = [mid - low for (mid, low) in zip(percentile_50_m, percentile_16_m)] - yerr_plus = [high - mid for high, mid in zip(percentile_84_m, percentile_50_m)] + # print(np.shape(samples[i])) + # STOP + percentile_16_m.append(np.percentile(samples[i, 0], 16)) + percentile_50_m.append(np.percentile(samples[i, 0], 50)) + percentile_84_m.append(np.percentile(samples[i, 0], 84)) + percentile_16_b.append(np.percentile(samples[i, 1], 16)) + percentile_50_b.append(np.percentile(samples[i, 1], 50)) + percentile_84_b.append(np.percentile(samples[i, 1], 84)) + yerr_minus = [mid - low for (mid, low) in zip(percentile_50_m, + percentile_16_m)] + yerr_plus = [high - mid for high, mid in zip(percentile_84_m, + percentile_50_m)] # Randomly set half of the error bars to zero - random_indices = np.random.choice(len(yerr_minus), int(len(yerr_minus) // 1.15), replace=False) + random_indices = np.random.choice( + len(yerr_minus), int(len(yerr_minus) // 1.15), replace=False + ) for idx in random_indices: yerr_minus[idx] = 0 yerr_plus[idx] = 0 - plt.errorbar(np.array(thetas[:,i]), - percentile_50_m, - yerr = [yerr_minus, yerr_plus], - linestyle = 'None', - color = color_list[i], - capsize = 5) - plt.scatter(np.array(thetas[:,0]), percentile_50_m, color = m_color) - plt.plot(percentile_50_m, percentile_50_m, color = 'k') - plt.xlabel('True value [m]') - plt.ylabel('Recovered value [m]') + plt.errorbar( + np.array(thetas[:, i]), + percentile_50_m, + yerr=[yerr_minus, yerr_plus], + linestyle="None", + color=color_list[i], + capsize=5, + ) + plt.scatter(np.array(thetas[:, 0]), percentile_50_m, color=m_color) + plt.plot(percentile_50_m, percentile_50_m, color="k") + plt.xlabel("True value [m]") + plt.ylabel("Recovered value [m]") plt.show() plt.clf() - plt.scatter(np.array(thetas[:,1]), percentile_50_b, color = b_color) - yerr_minus = [mid - low for (mid, low) in zip(percentile_50_b, percentile_16_b)] - yerr_plus = [high - mid for high, mid in zip(percentile_84_b, percentile_50_b)] + plt.scatter(np.array(thetas[:, 1]), percentile_50_b, color=b_color) + yerr_minus = [mid - low for (mid, low) in zip(percentile_50_b, + percentile_16_b)] + yerr_plus = [high - mid for high, mid in zip(percentile_84_b, + percentile_50_b)] # Randomly set half of the error bars to zero - random_indices = np.random.choice(len(yerr_minus), int(len(yerr_minus) // 1.15), replace=False) + random_indices = np.random.choice( + len(yerr_minus), int(len(yerr_minus) // 1.15), replace=False + ) for idx in random_indices: yerr_minus[idx] = 0 yerr_plus[idx] = 0 - plt.errorbar(np.array(thetas[:,1]), - percentile_50_b, - yerr = [yerr_minus, yerr_plus], - linestyle = 'None', - color = b_color, - capsize = 5) - plt.plot(percentile_50_b, percentile_50_b, color = 'black') - plt.xlabel('True value [b]') - plt.ylabel('Recovered value [b]') + plt.errorbar( + np.array(thetas[:, 1]), + percentile_50_b, + yerr=[yerr_minus, yerr_plus], + linestyle="None", + color=b_color, + capsize=5, + ) + plt.plot(percentile_50_b, percentile_50_b, color="black") + plt.xlabel("True value [b]") + plt.ylabel("Recovered value [b]") plt.show() - - - - - -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--path", type=str, help="path to saved posterior") parser.add_argument("--name", type=str, help="saved posterior name") @@ -479,4 +586,4 @@ def parameter_1_to_1_plots(samples,): # Load the model model = inference_model.load_model_pkl(args.path, args.name) - inference_obj = inference_model.predict(model) \ No newline at end of file + inference_obj = inference_model.predict(model) diff --git a/src/scripts/paths.py b/src/scripts/paths.py index 8c9434e..939c710 100644 --- a/src/scripts/paths.py +++ b/src/scripts/paths.py @@ -25,5 +25,6 @@ # Absolute path to the `src/tex/figures` folder (contains figure output) figures = tex / "figures" -# Absolute path to the `src/tex/output` folder (contains other user-defined output) -output = tex / "output" \ No newline at end of file +# Absolute path to the `src/tex/output` folder +# (contains other user-defined output) +output = tex / "output" diff --git a/src/scripts/train.py b/src/scripts/train.py index a24465b..3bba349 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -1,7 +1,9 @@ """ Simple stubs to use for re-train of the final model -Can leave a default data source, or specify that 'load data' loads the dataset used in the final version +Can leave a default data source, or specify that 'load data' loads the dataset +used in the final version """ + import argparse @@ -11,12 +13,14 @@ def architecture(): """ return 0 + def load_data(data_source): """ :return: data loader or full training data, split in val and train """ return 0, 0 + def train_model(data_source, n_epochs): """ :param data_source: @@ -26,13 +30,17 @@ def train_model(data_source, n_epochs): data = load_data(data_source) model = architecture() - return 0 + return data, model if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--data_source", type=str, help="Data used to train the model") - parser.add_argument("--n_epochs", type=int, help='Integer number of epochs to train the model') + parser.add_argument("--data_source", type=str, + help="Data used to train the model") + parser.add_argument( + "--n_epochs", type=int, + help="Integer number of epochs to train the model" + ) args = parser.parse_args() diff --git a/test/test_example.py b/test/test_example.py deleted file mode 100644 index 897c195..0000000 --- a/test/test_example.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Example of pytest functionality - -Goes over some basic assert examples -""" -import pytest - - -def test_example_assert_equal(): - assert 0 == 0 - - -def test_example_assert_no_equal(): - assert 0 != 1 - - -def test_example_assert_almost_equal(): - assert 1.0 == pytest.approx(1.01, .1) - - -""" -To run this suite of tests, run 'pytest' in the main directory -""" \ No newline at end of file diff --git a/tests/__pycache__/test_evaluate.cpython-39-pytest-7.3.2.pyc b/tests/__pycache__/test_evaluate.cpython-39-pytest-7.3.2.pyc new file mode 100644 index 0000000..7e3ad41 Binary files /dev/null and b/tests/__pycache__/test_evaluate.cpython-39-pytest-7.3.2.pyc differ diff --git a/tests/__pycache__/test_example.cpython-39-pytest-7.3.2.pyc b/tests/__pycache__/test_example.cpython-39-pytest-7.3.2.pyc new file mode 100644 index 0000000..f7ee021 Binary files /dev/null and b/tests/__pycache__/test_example.cpython-39-pytest-7.3.2.pyc differ diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py new file mode 100644 index 0000000..c7daac3 --- /dev/null +++ b/tests/test_evaluate.py @@ -0,0 +1,160 @@ +import sys +import pytest +import torch +import numpy as np +import sbi +import os + +# flake8: noqa +#sys.path.append("..") +from src.scripts.evaluate import Diagnose, InferenceModel +#from src.scripts import evaluate + + +""" +""" + + +""" +Test the evaluate module +""" + + +@pytest.fixture +def diagnose_instance(): + return Diagnose() + + +@pytest.fixture +def inference_instance(): + inference_model = InferenceModel() + path = "savedmodels/sbi/" + model_name = "sbi_linear" + posterior = inference_model.load_model_pkl(path, model_name) + return posterior + + +def simulator(thetas): # , percent_errors): + # convert to numpy array (if tensor): + thetas = np.atleast_2d(thetas) + # Check if the input has the correct shape + if thetas.shape[1] != 2: + raise ValueError( + "Input tensor must have shape (n, 2) \ + where n is the number of parameter sets." + ) + + # Unpack the parameters + if thetas.shape[0] == 1: + # If there's only one set of parameters, extract them directly + m, b = thetas[0, 0], thetas[0, 1] + else: + # If there are multiple sets of parameters, extract them for each row + m, b = thetas[:, 0], thetas[:, 1] + x = np.linspace(0, 100, 101) + rs = np.random.RandomState() # 2147483648)# + # I'm thinking sigma could actually be a function of x + # if we want to get fancy down the road + # Generate random noise (epsilon) based + # on a normal distribution with mean 0 and standard deviation sigma + sigma = 5 + ε = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0])) + + # Initialize an empty array to store the results for each set of parameters + y = np.zeros((len(x), thetas.shape[0])) + for i in range(thetas.shape[0]): + m, b = thetas[i, 0], thetas[i, 1] + y[:, i] = m * x + b + ε[:, i] + return torch.Tensor(y.T) + + +def test_generate_sbc_samples(diagnose_instance, inference_instance): + # Mock data + low_bounds = torch.tensor([0, -10]) + high_bounds = torch.tensor([10, 10]) + + prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds) + posterior = inference_instance # provide a mock posterior object + simulator_test = simulator # provide a mock simulator function + num_sbc_runs = 1000 + num_posterior_samples = 1000 + + # Generate SBC samples + thetas, ys, ranks, dap_samples = diagnose_instance.generate_sbc_samples( + prior, posterior, simulator_test, num_sbc_runs, num_posterior_samples + ) + + # Add assertions based on the expected behavior of the method + + +def test_run_all_sbc(diagnose_instance, inference_instance): + labels_list = ["$m$", "$b$"] + colorlist = ["#9C92A3", "#0F5257"] + low_bounds = torch.tensor([0, -10]) + high_bounds = torch.tensor([10, 10]) + + prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds) + posterior = inference_instance # provide a mock posterior object + simulator_test = simulator # provide a mock simulator function + + save_path = "plots/" + + diagnose_instance.run_all_sbc( + prior, + posterior, + simulator_test, + labels_list, + colorlist, + num_sbc_runs=1_000, + num_posterior_samples=1_000, + samples_per_inference=1_000, + plot=False, + save=True, + path=save_path, + ) + # Check if PDF files were saved + assert os.path.exists(save_path), f"No 'plots' folder found at {save_path}" + + # List all files in the directory + files_in_directory = os.listdir(save_path) + + # Check if at least one PDF file is present + pdf_files = [file for file in files_in_directory if file.endswith(".pdf")] + assert pdf_files, "No PDF files found in the 'plots' folder" + + # We expect the pdfs to exist in the directory + expected_pdf_files = ["sbc_ranks.pdf", "sbc_ranks_cdf.pdf", "coverage.pdf"] + for expected_file in expected_pdf_files: + assert ( + expected_file in pdf_files + ), f"Expected PDF file '{expected_file}' not found" + + +""" +def test_sbc_statistics(diagnose_instance): + # Mock data + ranks = # provide mock ranks + thetas = # provide mock thetas + dap_samples = # provide mock dap_samples + num_posterior_samples = 1000 + + # Calculate SBC statistics + check_stats = diagnose_instance.sbc_statistics( + ranks, thetas, dap_samples, num_posterior_samples + ) + + # Add assertions based on the expected behavior of the method + +def test_plot_1d_ranks(diagnose_instance): + # Mock data + ranks = # provide mock ranks + num_posterior_samples = 1000 + labels_list = # provide mock labels_list + colorlist = # provide mock colorlist + + # Plot 1D ranks + diagnose_instance.plot_1d_ranks( + ranks, num_posterior_samples, labels_list, + colorlist, plot=False, save=False + ) +"""