From fb53c1571a2ca02ac5a01abbf7a0d6000016dfa2 Mon Sep 17 00:00:00 2001 From: teddygroves Date: Mon, 19 Aug 2024 16:14:24 +0200 Subject: [PATCH] mcmc for methionine model --- pdm.lock | 343 +++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + src/enzax/mcmc.py | 72 ++++++---- 3 files changed, 386 insertions(+), 30 deletions(-) diff --git a/pdm.lock b/pdm.lock index 4db6d7b..671af47 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:bdffa204438cb62ff0bf976ced2ea734cb8800b8f1d4c53b0c631340c8579dd8" +content_hash = "sha256:7ca2553f5416737219a767df3aa5e2cb9290831765343009ebb4be6c440d2d12" [[metadata.targets]] requires_python = ">=3.12" @@ -21,6 +21,30 @@ files = [ {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, ] +[[package]] +name = "arviz" +version = "0.19.0" +requires_python = ">=3.10" +summary = "Exploratory analysis of Bayesian models" +groups = ["default"] +dependencies = [ + "dm-tree>=0.1.8", + "h5netcdf>=1.0.2", + "matplotlib>=3.5", + "numpy>=1.23.0", + "packaging", + "pandas>=1.5.0", + "scipy>=1.9.0", + "setuptools>=60.0.0", + "typing-extensions>=4.1.0", + "xarray-einstats>=0.3", + "xarray>=2022.6.0", +] +files = [ + {file = "arviz-0.19.0-py3-none-any.whl", hash = "sha256:652e739f1931b898922ebe9a9dd393840a0f3f7b803b606938edadc9c263a5bc"}, + {file = "arviz-0.19.0.tar.gz", hash = "sha256:04a369cd05293a17a24fcc1e133f9f350d202cf307f8df6c1dbd319e0f6a9f4f"}, +] + [[package]] name = "blackjax" version = "1.2.3" @@ -83,6 +107,29 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "contourpy" +version = "1.2.1" +requires_python = ">=3.9" +summary = "Python library for calculating contours of 2D quadrilateral grids" +groups = ["default"] +dependencies = [ + "numpy>=1.20", +] +files = [ + {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"}, + {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"}, + {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"}, + {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"}, + {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"}, + {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"}, + {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"}, + {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"}, +] + [[package]] name = "coverage" version = "7.6.1" @@ -168,6 +215,17 @@ files = [ {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, ] +[[package]] +name = "cycler" +version = "0.12.1" +requires_python = ">=3.8" +summary = "Composable style cycles" +groups = ["default"] +files = [ + {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"}, + {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, +] + [[package]] name = "diffrax" version = "0.6.0" @@ -198,6 +256,22 @@ files = [ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, ] +[[package]] +name = "dm-tree" +version = "0.1.8" +summary = "Tree is a library for working with nested data structures." +groups = ["default"] +files = [ + {file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, + {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, +] + [[package]] name = "equinox" version = "0.11.4" @@ -263,6 +337,57 @@ files = [ {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, ] +[[package]] +name = "fonttools" +version = "4.53.1" +requires_python = ">=3.8" +summary = "Tools to manipulate font files" +groups = ["default"] +files = [ + {file = "fonttools-4.53.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d92d3c2a1b39631a6131c2fa25b5406855f97969b068e7e08413325bc0afba58"}, + {file = "fonttools-4.53.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3b3c8ebafbee8d9002bd8f1195d09ed2bd9ff134ddec37ee8f6a6375e6a4f0e8"}, + {file = "fonttools-4.53.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32f029c095ad66c425b0ee85553d0dc326d45d7059dbc227330fc29b43e8ba60"}, + {file = "fonttools-4.53.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10f5e6c3510b79ea27bb1ebfcc67048cde9ec67afa87c7dd7efa5c700491ac7f"}, + {file = "fonttools-4.53.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f677ce218976496a587ab17140da141557beb91d2a5c1a14212c994093f2eae2"}, + {file = "fonttools-4.53.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9e6ceba2a01b448e36754983d376064730690401da1dd104ddb543519470a15f"}, + {file = "fonttools-4.53.1-cp312-cp312-win32.whl", hash = "sha256:791b31ebbc05197d7aa096bbc7bd76d591f05905d2fd908bf103af4488e60670"}, + {file = "fonttools-4.53.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ed170b5e17da0264b9f6fae86073be3db15fa1bd74061c8331022bca6d09bab"}, + {file = "fonttools-4.53.1-py3-none-any.whl", hash = "sha256:f1f8758a2ad110bd6432203a344269f445a2907dc24ef6bccfd0ac4e14e0d71d"}, + {file = "fonttools-4.53.1.tar.gz", hash = "sha256:e128778a8e9bc11159ce5447f76766cefbd876f44bd79aff030287254e4752c4"}, +] + +[[package]] +name = "h5netcdf" +version = "1.3.0" +requires_python = ">=3.9" +summary = "netCDF4 via h5py" +groups = ["default"] +dependencies = [ + "h5py", + "packaging", +] +files = [ + {file = "h5netcdf-1.3.0-py3-none-any.whl", hash = "sha256:f2df69dcd3665dc9c4d43eb6529dedd113b2508090d12ac973573305a8406465"}, + {file = "h5netcdf-1.3.0.tar.gz", hash = "sha256:a171c027daeb34b24c24a3b6304195b8eabbb6f10c748256ed3cfe19806383cf"}, +] + +[[package]] +name = "h5py" +version = "3.11.0" +requires_python = ">=3.8" +summary = "Read and write HDF5 files from Python" +groups = ["default"] +dependencies = [ + "numpy>=1.17.3", +] +files = [ + {file = "h5py-3.11.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a76cae64080210389a571c7d13c94a1a6cf8cb75153044fd1f822a962c97aeab"}, + {file = "h5py-3.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3736fe21da2b7d8a13fe8fe415f1272d2a1ccdeff4849c1421d2fb30fd533bc"}, + {file = "h5py-3.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa6ae84a14103e8dc19266ef4c3e5d7c00b68f21d07f2966f0ca7bdb6c2761fb"}, + {file = "h5py-3.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:21dbdc5343f53b2e25404673c4f00a3335aef25521bd5fa8c707ec3833934892"}, + {file = "h5py-3.11.0.tar.gz", hash = "sha256:7b7e8f78072a2edec87c9836f25f34203fd492a4475709a18b417a33cfb21fa9"}, +] + [[package]] name = "identify" version = "2.6.0" @@ -355,6 +480,34 @@ files = [ {file = "jaxtyping-0.2.33.tar.gz", hash = "sha256:9a9cfccae4fe05114b9fb27a5ea5440be4971a5a075bbd0526f6dd7d2730f83e"}, ] +[[package]] +name = "kiwisolver" +version = "1.4.5" +requires_python = ">=3.7" +summary = "A fast implementation of the Cassowary constraint solver" +groups = ["default"] +dependencies = [ + "typing-extensions; python_version < \"3.8\"", +] +files = [ + {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"}, + {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"}, + {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"}, + {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"}, + {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"}, + {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"}, + {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"}, + {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, +] + [[package]] name = "lineax" version = "0.0.5" @@ -372,6 +525,45 @@ files = [ {file = "lineax-0.0.5.tar.gz", hash = "sha256:ed41098dbd94b639287f15682e19cd9ec1dc9ecad6844ee861536044a0e3285c"}, ] +[[package]] +name = "matplotlib" +version = "3.9.2" +requires_python = ">=3.9" +summary = "Python plotting package" +groups = ["default"] +dependencies = [ + "contourpy>=1.0.1", + "cycler>=0.10", + "fonttools>=4.22.0", + "importlib-resources>=3.2.0; python_version < \"3.10\"", + "kiwisolver>=1.3.1", + "numpy>=1.23", + "packaging>=20.0", + "pillow>=8", + "pyparsing>=2.3.1", + "python-dateutil>=2.7", +] +files = [ + {file = "matplotlib-3.9.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c"}, + {file = "matplotlib-3.9.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e"}, + {file = "matplotlib-3.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413"}, + {file = "matplotlib-3.9.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b"}, + {file = "matplotlib-3.9.2-cp313-cp313-win_amd64.whl", hash = "sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c"}, + {file = "matplotlib-3.9.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e"}, + {file = "matplotlib-3.9.2.tar.gz", hash = "sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92"}, +] + [[package]] name = "ml-dtypes" version = "0.4.0" @@ -479,12 +671,69 @@ name = "packaging" version = "24.1" requires_python = ">=3.8" summary = "Core utilities for Python packages" -groups = ["dev"] +groups = ["default", "dev"] files = [ {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] +[[package]] +name = "pandas" +version = "2.2.2" +requires_python = ">=3.9" +summary = "Powerful data structures for data analysis, time series, and statistics" +groups = ["default"] +dependencies = [ + "numpy>=1.22.4; python_version < \"3.11\"", + "numpy>=1.23.2; python_version == \"3.11\"", + "numpy>=1.26.0; python_version >= \"3.12\"", + "python-dateutil>=2.8.2", + "pytz>=2020.1", + "tzdata>=2022.7", +] +files = [ + {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, + {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, + {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, +] + +[[package]] +name = "pillow" +version = "10.4.0" +requires_python = ">=3.8" +summary = "Python Imaging Library (Fork)" +groups = ["default"] +files = [ + {file = "pillow-10.4.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94"}, + {file = "pillow-10.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef"}, + {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a"}, + {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b"}, + {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9"}, + {file = "pillow-10.4.0-cp312-cp312-win32.whl", hash = "sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42"}, + {file = "pillow-10.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a"}, + {file = "pillow-10.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9"}, + {file = "pillow-10.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3"}, + {file = "pillow-10.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0"}, + {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc"}, + {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a"}, + {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309"}, + {file = "pillow-10.4.0-cp313-cp313-win32.whl", hash = "sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060"}, + {file = "pillow-10.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea"}, + {file = "pillow-10.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d"}, + {file = "pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06"}, +] + [[package]] name = "platformdirs" version = "4.2.2" @@ -525,6 +774,17 @@ files = [ {file = "pre_commit-3.8.0.tar.gz", hash = "sha256:8bb6494d4a20423842e198980c9ecf9f96607a07ea29549e180eef9ae80fe7af"}, ] +[[package]] +name = "pyparsing" +version = "3.1.2" +requires_python = ">=3.6.8" +summary = "pyparsing module - Classes and methods to define and execute parsing grammars" +groups = ["default"] +files = [ + {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, + {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, +] + [[package]] name = "pytest" version = "8.3.2" @@ -559,6 +819,30 @@ files = [ {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, ] +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +summary = "Extensions to the standard Python datetime module" +groups = ["default"] +dependencies = [ + "six>=1.5", +] +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[[package]] +name = "pytz" +version = "2024.1" +summary = "World timezone definitions, modern and historical" +groups = ["default"] +files = [ + {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, + {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, +] + [[package]] name = "pyyaml" version = "6.0.2" @@ -614,12 +898,22 @@ version = "72.2.0" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" groups = ["default"] -marker = "python_version >= \"3.12\"" files = [ {file = "setuptools-72.2.0-py3-none-any.whl", hash = "sha256:f11dd94b7bae3a156a95ec151f24e4637fb4fa19c878e4d191bfb8b2d82728c4"}, {file = "setuptools-72.2.0.tar.gz", hash = "sha256:80aacbf633704e9c8bfa1d99fa5dd4dc59573efcf9e4042c13d3bcef91ac2ef9"}, ] +[[package]] +name = "six" +version = "1.16.0" +requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +summary = "Python 2 and 3 compatibility utilities" +groups = ["default"] +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "toolz" version = "0.12.1" @@ -653,6 +947,17 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "tzdata" +version = "2024.1" +requires_python = ">=2" +summary = "Provider of IANA time zone data" +groups = ["default"] +files = [ + {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, + {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, +] + [[package]] name = "virtualenv" version = "20.26.3" @@ -669,3 +974,35 @@ files = [ {file = "virtualenv-20.26.3-py3-none-any.whl", hash = "sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589"}, {file = "virtualenv-20.26.3.tar.gz", hash = "sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a"}, ] + +[[package]] +name = "xarray" +version = "2024.7.0" +requires_python = ">=3.9" +summary = "N-D labeled arrays and datasets in Python" +groups = ["default"] +dependencies = [ + "numpy>=1.23", + "packaging>=23.1", + "pandas>=2.0", +] +files = [ + {file = "xarray-2024.7.0-py3-none-any.whl", hash = "sha256:1b0fd51ec408474aa1f4a355d75c00cc1c02bd425d97b2c2e551fd21810e7f64"}, + {file = "xarray-2024.7.0.tar.gz", hash = "sha256:4cae512d121a8522d41e66d942fb06c526bc1fd32c2c181d5fe62fe65b671638"}, +] + +[[package]] +name = "xarray-einstats" +version = "0.7.0" +requires_python = ">=3.9" +summary = "Stats, linear algebra and einops for xarray" +groups = ["default"] +dependencies = [ + "numpy>=1.22", + "scipy>=1.8", + "xarray>=2022.09.0", +] +files = [ + {file = "xarray_einstats-0.7.0-py3-none-any.whl", hash = "sha256:f39403341ebf5b634ab1f1bd0e1bb2dc51046e0df31aa908dfbe2fa6a493712e"}, + {file = "xarray_einstats-0.7.0.tar.gz", hash = "sha256:2d7b571b3bbad3cf2fd10c6c75fd949d247d14c29574184c8489d9d607278d38"}, +] diff --git a/pyproject.toml b/pyproject.toml index 43d59b0..5513fe1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "blackjax>=1.2.1", "diffrax>=0.6.0", "jaxtyping>=0.2.31", + "arviz>=0.19.0", ] requires-python = ">=3.12" readme = "README.md" diff --git a/src/enzax/mcmc.py b/src/enzax/mcmc.py index 3be5ed7..bf18618 100644 --- a/src/enzax/mcmc.py +++ b/src/enzax/mcmc.py @@ -1,26 +1,25 @@ """Code for doing mcmc on the parameters of a steady state problem.""" import functools +import logging +import warnings +import arviz as az import blackjax import chex import equinox as eqx import jax import jax.numpy as jnp from jax.scipy.stats import norm +from jaxtyping import Array, Float, ScalarLike from enzax.examples import methionine from enzax.kinetic_model import ( - KineticModelParameters, KineticModel, + KineticModelParameters, UnparameterisedKineticModel, ) -from enzax.rate_equations import ( - AllostericReversibleMichaelisMenten, - ReversibleMichaelisMenten, -) from enzax.steady_state_problem import solve -from jaxtyping import Array, Float, ScalarLike SEED = 1234 jax.config.update("jax_enable_x64", True) @@ -41,11 +40,12 @@ class ObservationSet: class PriorSet: log_kcat: Float[Array, "2 n"] log_enzyme: Float[Array, "2 n"] + log_drain: Float[Array, "2 n_drain"] dgf: Float[Array, "2 n_metabolite"] log_km: Float[Array, "2 n_km"] + log_ki: Float[Array, "2 n_ki"] log_conc_unbalanced: Float[Array, "2 n_unbalanced"] temperature: Float[Array, "2"] - log_ki: Float[Array, " n_ki"] log_transfer_constant: Float[Array, "2 n_allosteric_enzyme"] log_dissociation_constant: Float[Array, "2 n_allosteric_effector"] @@ -80,6 +80,7 @@ def posterior_logdensity_fn( prior_logdensity = ( ind_normal_prior_logdensity(parameters.log_kcat, prior.log_kcat) + ind_normal_prior_logdensity(parameters.log_enzyme, prior.log_enzyme) + + ind_normal_prior_logdensity(parameters.log_drain, prior.log_drain) + ind_normal_prior_logdensity(parameters.dgf, prior.dgf) + ind_normal_prior_logdensity(parameters.log_km, prior.log_km) + ind_normal_prior_logdensity( @@ -101,13 +102,18 @@ def posterior_logdensity_fn( @functools.partial(jax.jit, static_argnames=["kernel", "num_samples"]) def inference_loop(rng_key, kernel, initial_state, num_samples): def one_step(state, rng_key): - state, _ = kernel(rng_key, state) - return state, state + state, info = kernel(rng_key, state) + return state, (state, info) keys = jax.random.split(rng_key, num_samples) - _, states = jax.lax.scan(one_step, initial_state, keys) + _, (states, info) = jax.lax.scan(one_step, initial_state, keys) + + return states, info - return states + +def warn_if_divergent(info): + if jnp.any(info.is_divergent): + warnings.warn("I found a divergent transition!") @eqx.filter_jit @@ -122,47 +128,51 @@ def sample(logdensity_fn, rng_key, init_parameters): target_acceptance_rate=0.95, ) rng_key, warmup_key = jax.random.split(rng_key) - (initial_state, tuned_parameters), _ = warmup.run( + (initial_state, tuned_parameters), (_, info, _) = warmup.run( warmup_key, init_parameters, num_steps=1000, # type: ignore ) rng_key, sample_key = jax.random.split(rng_key) nuts_kernel = blackjax.nuts(logdensity_fn, **tuned_parameters).step - states = inference_loop( + states, info = inference_loop( sample_key, kernel=nuts_kernel, initial_state=initial_state, num_samples=200, ) - return states + return states, info def ind_prior_from_truth(truth, sd): return jnp.vstack((truth, jnp.full(truth.shape, sd))) +def get_idata(samples, info): + sample_dict = { + k: jnp.expand_dims(getattr(samples.position, k), 0) + for k in samples.position.__dataclass_fields__.keys() + } + posterior = az.convert_to_inference_data(sample_dict, group="posterior") + sample_stats = az.convert_to_inference_data( + {"diverging": info.is_divergent}, group="sample_stats" + ) + return az.concat(posterior, sample_stats) + + def main(): """Demonstrate the functionality of the mcmc module.""" true_parameters = methionine.parameters - unparameterised_model = UnparameterisedKineticModel( - structure=methionine.structure, - rate_equation_classes=[ - AllostericReversibleMichaelisMenten, - AllostericReversibleMichaelisMenten, - ReversibleMichaelisMenten, - ], - ) - default_state_guess = jnp.array([0.1, 0.1]) - true_model = KineticModel( - parameters=true_parameters, unparameterised_model=unparameterised_model - ) + unparameterised_model = methionine.unparameterised_model + true_model = methionine.model + default_state_guess = jnp.full((5,), 0.01) true_states = solve( true_parameters, unparameterised_model, default_state_guess ) prior = PriorSet( log_kcat=ind_prior_from_truth(true_parameters.log_kcat, 0.1), log_enzyme=ind_prior_from_truth(true_parameters.log_enzyme, 0.1), + log_drain=ind_prior_from_truth(true_parameters.log_drain, 0.1), dgf=ind_prior_from_truth(true_parameters.dgf, 0.1), log_km=ind_prior_from_truth(true_parameters.log_km, 0.1), log_conc_unbalanced=ind_prior_from_truth( @@ -212,7 +222,15 @@ def main(): unparameterised_model=unparameterised_model, guess=default_state_guess, ) - samples = sample(log_M, key, true_parameters) + samples, info = sample(log_M, key, true_parameters) + idata = get_idata(samples, info) + print(az.summary(idata)) + if jnp.any(info.is_divergentl): + n_divergent = info.is_divergent.sum() + msg = f"There were {n_divergent} post-warmup divergent transitions." + warnings.warn(msg) + else: + logging.info("No post-warmup divergent transitions!") print("True parameter values vs posterior:") for param in true_parameters.__dataclass_fields__.keys(): true_val = getattr(true_parameters, param)