diff --git a/cloud/batch_cdk/image/poetry.lock b/cloud/batch_cdk/image/poetry.lock index b104689a..c360e8c3 100644 --- a/cloud/batch_cdk/image/poetry.lock +++ b/cloud/batch_cdk/image/poetry.lock @@ -1697,29 +1697,33 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec [[package]] name = "jax" -version = "0.4.6" +version = "0.4.13" description = "Differentiate, compile, and transform Numpy code." category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jax-0.4.6.tar.gz", hash = "sha256:d06ea8fba4ed315ec55110396058cb48c8edb2ab0b412f28c8a123beee9e58ab"}, + {file = "jax-0.4.13.tar.gz", hash = "sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa"}, ] [package.dependencies] -numpy = ">=1.20" +ml_dtypes = ">=0.1.0" +numpy = ">=1.21" opt_einsum = "*" -scipy = ">=1.5" +scipy = ">=1.7" [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.4)"] -cpu = ["jaxlib (==0.4.6)"] -cuda = ["jaxlib (==0.4.6+cuda11.cudnn86)"] -cuda11-cudnn82 = ["jaxlib (==0.4.6+cuda11.cudnn82)"] -cuda11-cudnn86 = ["jaxlib (==0.4.6+cuda11.cudnn86)"] -minimum-jaxlib = ["jaxlib (==0.4.4)"] -tpu = ["jaxlib (==0.4.6)", "libtpu-nightly (==0.1.dev20230309)", "requests"] +ci = ["jaxlib (==0.4.12)"] +cpu = ["jaxlib (==0.4.13)"] +cuda = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.13+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.13+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.13+cuda12.cudnn89)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib (==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] [[package]] name = "jaxlib" @@ -4783,4 +4787,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "~3.10" -content-hash = "7f388bc931242642f55b1a7eda098ef741565b9b7df83b39c916ef9f04edc897" +content-hash = "fb6fe57abfba58c692429bb29c54f366e7ac8411c21e1f37371ecc4f3dd53dce" diff --git a/cloud/batch_cdk/image/pyproject.toml b/cloud/batch_cdk/image/pyproject.toml index ba859125..f1e51542 100644 --- a/cloud/batch_cdk/image/pyproject.toml +++ b/cloud/batch_cdk/image/pyproject.toml @@ -14,7 +14,7 @@ scipy = "^1.9.3" sympy = "^1.11.1" matplotlib = "^3.6.2" pandas = "^1.5.1" -jax = "0.4.6" +jax = "0.4.13" numpyro = "^0.10.1" jaxlib = [ {version = "0.3.22", platform = "darwin", source="pypi"},