Skip to content

Commit

Permalink
Prepare for v0.4.33 release.
Browse files Browse the repository at this point in the history
This release is branched off the v0.4.32 release, with two changes:
a) a fixed libtpu pin, and
b) a patch to revert an F64 tanh issue on CPU.
  • Loading branch information
hawkinsp committed Sep 16, 2024
1 parent 1594d2f commit 80e1c94
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 6 deletions.
24 changes: 22 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,24 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.32
## jax 0.4.33

This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
release.

A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of `libtpu`.

## jaxlib 0.4.33

This release fixes an inaccurate result for F64 tanh on CPU (#23590).

## jax 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* New Functionality
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
Expand Down Expand Up @@ -65,7 +82,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
another framework that implements the ``__dlpack__`` protocol.

## jaxlib 0.4.32
## jaxlib 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* Breaking changes
* Hermetic CUDA support is added.
Expand Down
4 changes: 2 additions & 2 deletions jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pathlib
import subprocess

_version = "0.4.32"
_version = "0.4.33"
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
_release_version: str | None = None
Expand Down Expand Up @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files):


__version__ = _get_version_string()
_minimum_jaxlib_version = "0.4.32"
_minimum_jaxlib_version = "0.4.33"

def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

project_name = 'jax'

_current_jaxlib_version = '0.4.32'
_current_jaxlib_version = '0.4.33'
# The following should be updated after each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.4.31'
_libtpu_version = '0.1.dev20240911'
_libtpu_version = '0.1.dev20240916'

def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
Expand Down
14 changes: 14 additions & 0 deletions third_party/xla/tanh.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
diff --git a/xla/service/cpu/llvm_ir_runtime.cc b/xla/service/cpu/llvm_ir_runtime.cc
index 89b40b915caa3..25541c16bfd61 100644
--- a/xla/service/cpu/llvm_ir_runtime.cc
+++ b/xla/service/cpu/llvm_ir_runtime.cc
@@ -410,7 +410,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module,
rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8);
rewrite_calls(kTanhV16F32SymbolName, GenerateVF32Tanh, /*vector_width=*/16);

- rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1);
+ // TODO(penporn): Re-enable after fixing JAX issue #23590.
+ // rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1);

rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1);
rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1);
3 changes: 3 additions & 0 deletions third_party/xla/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def repo():
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
patch_file = [
"//third_party/xla:tanh.patch",
],
)

# For development, one often wants to make changes to the TF repository as well
Expand Down

0 comments on commit 80e1c94

Please sign in to comment.