diff --git a/setup.py b/setup.py index fab9f729f82b..94d593b08c3b 100644 --- a/setup.py +++ b/setup.py @@ -22,13 +22,13 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.20' +_current_jaxlib_version = '0.4.21' # The following should be updated with each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.20' _available_cuda11_cudnn_versions = ['86'] _default_cuda11_cudnn_version = '86' _default_cuda12_cudnn_version = '89' -_libtpu_version = '0.1.dev20231102' +_libtpu_version = '0.1.dev20231204' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 6fd8ec1d507e..aa245e56d9f3 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -20,8 +20,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "96c4f8749b6521ff3e2670d168c9095a5f6323c5" -XLA_SHA256 = "68601c20ab970ca7d35fead91566f847152b1e5a14319e2ffaf9bc35243e1292" +XLA_COMMIT = "96c964be6e99b57d5edbe65e0185bc7898e2e9c1" +XLA_SHA256 = "567182c0dde8d7c80938453ae0fa5904b3d004b4e3906888a2b20be9454731bf" def repo(): tf_http_archive(