diff --git a/cloud/batch_cdk/image/poetry.lock b/cloud/batch_cdk/image/poetry.lock index b104689a..d563c14a 100644 --- a/cloud/batch_cdk/image/poetry.lock +++ b/cloud/batch_cdk/image/poetry.lock @@ -1697,29 +1697,34 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec [[package]] name = "jax" -version = "0.4.6" +version = "0.4.12" description = "Differentiate, compile, and transform Numpy code." category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jax-0.4.6.tar.gz", hash = "sha256:d06ea8fba4ed315ec55110396058cb48c8edb2ab0b412f28c8a123beee9e58ab"}, + {file = "jax-0.4.12.tar.gz", hash = "sha256:d2de9a2388ffe002f16506d3ad1cc6e34d7536b98948e49c7e05bbcfe8e57998"}, ] [package.dependencies] -numpy = ">=1.20" +ml_dtypes = ">=0.1.0" +numpy = ">=1.21" opt_einsum = "*" -scipy = ">=1.5" +scipy = ">=1.7" [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.4)"] -cpu = ["jaxlib (==0.4.6)"] -cuda = ["jaxlib (==0.4.6+cuda11.cudnn86)"] -cuda11-cudnn82 = ["jaxlib (==0.4.6+cuda11.cudnn82)"] -cuda11-cudnn86 = ["jaxlib (==0.4.6+cuda11.cudnn86)"] -minimum-jaxlib = ["jaxlib (==0.4.4)"] -tpu = ["jaxlib (==0.4.6)", "libtpu-nightly (==0.1.dev20230309)", "requests"] +ci = ["jaxlib (==0.4.11)"] +cpu = ["jaxlib (==0.4.12)"] +cuda = ["jaxlib (==0.4.12+cuda11.cudnn86)"] +cuda11-cudnn82 = ["jaxlib (==0.4.12+cuda11.cudnn82)"] +cuda11-cudnn86 = ["jaxlib (==0.4.12+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.12+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.12+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.12+cuda12.cudnn88)"] +cuda12-pip = ["jaxlib (==0.4.12+cuda12.cudnn88)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib (==0.4.12)", "libtpu-nightly (==0.1.dev20230608)"] [[package]] name = "jaxlib" @@ -4783,4 +4788,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "~3.10" -content-hash = "7f388bc931242642f55b1a7eda098ef741565b9b7df83b39c916ef9f04edc897" +content-hash = "2d98caa5c40055dfe9bec898e994f77f1689d12979161c8046bb35cfa2217770" diff --git a/cloud/batch_cdk/image/pyproject.toml b/cloud/batch_cdk/image/pyproject.toml index ba859125..9f1f4337 100644 --- a/cloud/batch_cdk/image/pyproject.toml +++ b/cloud/batch_cdk/image/pyproject.toml @@ -14,7 +14,7 @@ scipy = "^1.9.3" sympy = "^1.11.1" matplotlib = "^3.6.2" pandas = "^1.5.1" -jax = "0.4.6" +jax = "0.4.12" numpyro = "^0.10.1" jaxlib = [ {version = "0.3.22", platform = "darwin", source="pypi"},