Best viewed here.
- GitHub commits.
- New features:
- When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static keyword arguments. A new `static_argnames` option has been added to specify keyword arguments as static.
- Breaking changes:
- Arguments to {func}`jax.jit` other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to `jit`.
- Bug fixes:
- {func}`jax2tf.convert` now works in the presence of gradients for functions with integer inputs ({jax-issue}`#6360`).
- Added support for static keyword arguments to the C++ `jit` implementation.
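
As a rough illustration of the `static_argnames` option described above, the sketch below marks a keyword argument as static so that it can drive ordinary Python control flow inside a jitted function. The function name `scale` and its `mode` argument are hypothetical and are not part of JAX.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `mode` is a hypothetical keyword argument used only for illustration.
# Marking it static means each distinct value triggers its own compilation,
# and the Python `if` below is resolved at trace time.
@partial(jax.jit, static_argnames=("mode",))
def scale(x, *, mode="double"):
    if mode == "double":
        return x * 2
    return x / 2


out_double = scale(jnp.arange(3.0), mode="double")
out_half = scale(jnp.arange(3.0), mode="half")
```

Because the options to `jax.jit` are keyword-only, `functools.partial` is a convenient way to apply them in decorator form.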
- GitHub commits.
- New features
- New profiling APIs: {func}`jax.profiler.start_trace`, {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`.
- {func}`jax.lax.reduce` is now differentiable.
- Breaking changes:
- The minimum jaxlib version is now 0.1.64.
- Some profiler API names have been changed. Aliases remain for now, so existing code should not break, but the aliases will eventually be removed, so please update your code.
  - `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
  - `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
  - `trace_function` --> {func}`~jax.profiler.annotate_function`
- Omnistaging can no longer be disabled. See omnistaging for more information.
- Python integers larger than the maximum `int64` value will now lead to an overflow in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
- Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an `OverflowError` rather than having their value silently truncated.
- Bug fixes:
- `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).
- {func}`jax.random.randint` clips rather than wraps out-of-bounds limits, and can now generate integers in the full range of the specified dtype ({jax-issue}`#5868`).
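
The renamed profiler annotations listed above can be used roughly as follows. This is a minimal sketch: the function `normalize` and the label `"normalize_step"` are made up for illustration, and the annotations only show up in a profile when a trace is actually being recorded.

```python
import jax
import jax.numpy as jnp


# Decorator form (formerly `trace_function`): labels calls to this function.
@jax.profiler.annotate_function
def normalize(x):
    return x / jnp.linalg.norm(x)


# Context-manager form (formerly `TraceContext`): labels a block of code.
with jax.profiler.TraceAnnotation("normalize_step"):
    unit = normalize(jnp.array([3.0, 4.0]))
```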
- New features:
- Bug fixes:
- #6136 generalized `jax.flatten_util.ravel_pytree` to handle integer dtypes.
- #6129 fixed a bug with handling some constants like `enum.IntEnums`.
- #6145 fixed batching issues with incomplete beta functions.
- #6014 fixed H2D transfers during tracing.
- #6165 avoids OverflowErrors when converting some large Python integers to floats.
- Breaking changes:
- The minimum jaxlib version is now 0.1.62.
- GitHub commits.
- New features:
- {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
- {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
- Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions from JAX ({jax-issue}`#5627` and README).
- Extended the batching rule for `lax.pad` to support batching of the padding values.
- Bug fixes:
- {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`).
- Breaking changes:
- JAX's promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example the expression `jnp.bfloat16(1) + 0.1 * jnp.arange(10)` previously returned a `float64` array, and now returns a `bfloat16` array. JAX's type promotion behavior is described at {ref}`type-promotion`.
- {func}`jax.numpy.linspace` now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.
- {func}`jax.numpy.i0` no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.
- Several {mod}`jax.numpy` functions no longer accept tuples or lists in place of array arguments: {func}`jax.numpy.pad`, {func}`jax.numpy.ravel`, {func}`jax.numpy.repeat`, {func}`jax.numpy.reshape`. In general, {mod}`jax.numpy` functions should be used with scalars or array arguments.
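
The promotion change above can be checked directly. The snippet is a small sketch of the expression quoted in the entry; the variable names are illustrative only.

```python
import jax.numpy as jnp

# The Python float 0.1 is weakly typed, so multiplying it by an integer array
# yields a weakly-typed float result that does not force a wider dtype.
scaled = 0.1 * jnp.arange(10)

# Adding a bfloat16 scalar keeps the result in bfloat16 under the new rules.
result = jnp.bfloat16(1) + scaled
```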
- New features:
- jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn't support AVX, you can build a jaxlib from source using the `--target_cpu_features` flag to `build.py`. `--target_cpu_features` also replaces `--enable_march_native`.
- Bug fixes:
- Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
- `bool`, `int8`, and `uint8` are now considered safe to cast to the `bfloat16` NumPy extension type.
- GitHub commits.
- New features:
- Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved error checking and error messages.
- Add {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`. These are context managers which allow X64 mode to be temporarily enabled/disabled within a session.
- Breaking changes:
- {func}`jax.ops.segment_sum` now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
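
To make the `segment_sum` change concrete, here is a minimal sketch with made-up data in which one segment ID lies outside the valid range and is therefore dropped from the result.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0])
# The last ID (5) is out of range for num_segments=2, so 4.0 is dropped
# instead of being wrapped into segment 0 or 1.
segment_ids = jnp.array([0, 1, 1, 5])

totals = jax.ops.segment_sum(data, segment_ids, num_segments=2)
```

Here `totals` holds the sum for segment 0 (1.0) and segment 1 (2.0 + 3.0).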
- GitHub commits.
- New features:
- Add {func}`jax.closure_convert` for use with higher-order custom derivative functions. ({jax-issue}`#5244`)
- Add {func}`jax.experimental.host_callback.call` to call a custom Python function on the host and return a result to the device computation. ({jax-issue}`#5243`)
- Bug fixes:
- `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for complex inputs ({jax-issue}`#5156`).
- `host_callback.id_tap` now works for `jax.pmap` also. There is an optional parameter for `id_tap` and `id_print` to request that the device from which the value is tapped be passed as a keyword argument to the tap function ({jax-issue}`#5182`).
- Breaking changes:
- `jax.numpy.pad` now takes keyword arguments. Positional argument `constant_values` has been removed. In addition, passing unsupported keyword arguments raises an error.
- Changes for {func}`jax.experimental.host_callback.id_tap` ({jax-issue}`#5243`):
  - Removed support for `kwargs` for {func}`jax.experimental.host_callback.id_tap`. (This support has been deprecated for a few months.)
  - Changed the printing of tuples for {func}`jax.experimental.host_callback.id_print` to use '(' instead of '['.
  - Changed the {func}`jax.experimental.host_callback.id_print` in presence of JVP to print a pair of primal and tangent. Previously, there were two separate print operations for the primals and the tangent.
  - `host_callback.outfeed_receiver` has been removed (it is not necessary, and was deprecated a few months ago).
- New features:
- New flag for debugging `inf`, analogous to that for `NaN` ({jax-issue}`#5224`).
- GitHub commits.
- New features:
- Add `jax.device_put_replicated`
- Add multi-host support to `jax.experimental.sharded_jit`
- Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
- Add support for building on Windows platforms
- Add support for general in_axes and out_axes in `jax.pmap`
- Add complex support for `jax.numpy.linalg.slogdet`
- Bug fixes:
- Fix higher-than-second order derivatives of `jax.numpy.sinc` at zero
- Fix some hard-to-hit bugs around symbolic zeros in transpose rules
- Breaking changes:
- `jax.experimental.optix` has been deleted, in favor of the standalone `optax` Python package.
- Indexing of JAX arrays with non-tuple sequences now raises a `TypeError`. This type of indexing has been deprecated in NumPy since v1.16, and in JAX since v0.2.4. See {jax-issue}`#4564`.
- New Features:
- Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
- Breaking change cleanup
- Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.
- Improve consistency of type promotion behavior ({jax-issue}`#4744`):
  - Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously it returned `complex128`.
  - Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: `jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and `jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously the first returned `float64` and the second returned `float16`.
- The contents of the (undocumented) `jax.lax_linalg` linear algebra module are now exposed publicly as `jax.lax.linalg`.
- `jax.random.PRNGKey` now produces the same results in and out of JIT compilation ({jax-issue}`#4877`). This required changing the result for a given seed in a few particular cases:
  - With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned `[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
  - Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError` rather than a `TypeError`. This matches the behavior in JIT.

  To recover the keys returned previously for negative integers with `jax_enable_x64=False` outside JIT, you can use: `key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)`
- DeviceArray now raises `RuntimeError` instead of `ValueError` when trying to access its value while it has been deleted.
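
The complex-scalar promotion rule described above can be verified with a short check. This is only a sketch of the example given in the entry; the variable name is illustrative.

```python
import jax.numpy as jnp

# A Python complex scalar is weakly typed, so the result keeps the
# precision of the float32 operand rather than widening to complex128.
z = jnp.float32(1) + 1j
```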
- Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., `np.cint`) instead of standard types (e.g., `np.int32`). (#4903)
- Fixed a crash when constant-folding certain int16 operations. (#4971)
- Added an `is_leaf` predicate to {func}`pytree.flatten`.
- Fixed manylinux2010 compliance issues in GPU wheels.
- Switched the CPU FFT implementation from Eigen to PocketFFT.
- Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
- Add support for retaining ownership when passing arrays to DLPack (#4636).
- Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
- Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
- Fixed a bug in profiler where tools are missing (#4427).
- Dropped support for CUDA 10.0.
- GitHub commits.
- Improvements:
- Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`.
- Expanded the set of JAX primitives converted by jax2tf. See primitives_with_limited_support.md.
- Improvements:
- Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`.
- Deprecations
- Indexing with non-tuple sequences is now deprecated, following a similar deprecation in NumPy. In a future release, this will result in a TypeError. See {jax-issue}`#4564`.
- GitHub commits.
- The reason for another release so soon is we need to temporarily roll back a new jit fastpath while we look into a performance degradation
- GitHub commits.
- Improvements:
- As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/{py:func}`jax.experimental.host_callback.id_tap` is not used in the computation.
- GitHub commits.
- Improvements:
- Omnistaging on by default. See {jax-issue}`#3370` and omnistaging
- Breaking changes:
- New simplified interface for {py:func}`jax.experimental.host_callback.id_tap` (#4101)
- Update XLA:
- Fix bug in DLPackManagedTensorToBuffer (#4196)
- GitHub commits.
- Bug Fixes:
- make jnp.abs() work for unsigned inputs (#3914)
- Improvements:
- "Omnistaging" behavior added behind a flag, disabled by default (#3370)
- GitHub commits.
- New Features:
- BFGS (#3101)
- TPU support for half-precision arithmetic (#3878)
- Bug Fixes:
- Prevent some accidental dtype warnings (#3874)
- Fix a multi-threading bug in custom derivatives (#3845, #3869)
- Improvements:
- Faster searchsorted implementation (#3873)
- Better test coverage for jax.numpy sorting algorithms (#3836)
- Update XLA.
- GitHub commits.
- The minimum jaxlib version is now 0.1.51.
- New Features:
- jax.image.resize. (#3703)
- hfft and ihfft (#3664)
- jax.numpy.intersect1d (#3726)
- jax.numpy.lexsort (#3812)
- `lax.scan` and the `scan` primitive support an `unroll` parameter for loop unrolling when lowering to XLA ({jax-issue}`#3738`).
- Bug Fixes:
- Fix reduction repeated axis error (#3618)
- Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
- make psum transpose handle zero cotangents (#3653)
- Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
- Support differentiation through jax.lax.all_to_all (#3733)
- address nan issue in jax.scipy.special.zeta (#3777)
- Improvements:
- Many improvements to jax2tf
- Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
- Enable XLA SPMD partitioning by default. (#3151)
- Add support for 0d transpose convolution (#3643)
- Make LU gradient work for low-rank matrices (#3610)
- support multiple_results and custom JVPs in jet (#3657)
- Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
- Implement complex convolutions on CPU and GPU. (#3735)
- Make jnp.take work for empty slices of empty arrays. (#3751)
- Relax dimension ordering rules for dot_general. (#3778)
- Enable buffer donation for GPU. (#3800)
- Add support for base dilation and window dilation to reduce window op… (#3803)
- Update XLA.
- Add new runtime support for host_callback.
- GitHub commits.
- Bug fixes:
- Fix an odeint bug introduced in the previous release, see {jax-issue}`#3587`.
- GitHub commits.
- The minimum jaxlib version is now 0.1.48.
- Bug fixes:
- Allow `jax.experimental.ode.odeint` dynamics functions to close over values with respect to which we're differentiating {jax-issue}`#3562`.
- Add support for CUDA 11.0.
- Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
- Update XLA.
- Bug fixes:
- Fix build issue that could result in slow compiles (https://github.com/tensorflow/tensorflow/commit/f805153a25b00d12072bd728e91bb1621bfcf1b1)
- New features:
- Adds support for fast traceback collection.
- Adds preliminary support for on-device heap profiling.
- Implements `np.nextafter` for `bfloat16` types.
- Complex128 support for FFTs on CPU and GPU.
- Bugfixes:
- Improved float64 `tanh` accuracy on GPU.
- float64 scatters on GPU are much faster.
- Complex matrix multiplication on CPU should be much faster.
- Stable sorts on CPU should actually be stable now.
- Concurrency bug fix in CPU backend.
- GitHub commits.
- New features:
- `lax.switch` introduces indexed conditionals with multiple branches, together with a generalization of the `cond` primitive {jax-issue}`#3318`.
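
A minimal sketch of an indexed conditional with `lax.switch` follows. The three branch functions are arbitrary examples, and the second call relies on the documented behavior that an out-of-range index is clamped to the valid range.

```python
import jax.numpy as jnp
from jax import lax

# Three illustrative branches; all must return outputs of the same type.
branches = [
    lambda x: x + 1.0,
    lambda x: x * 2.0,
    lambda x: -x,
]

operand = jnp.float32(3.0)
picked = lax.switch(1, branches, operand)   # runs the second branch
clamped = lax.switch(7, branches, operand)  # index clamped to the last branch
```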
- GitHub commits.
- New features:
- {func}`lax.cond` supports a single-operand form, taken as the argument to both branches {jax-issue}`#2993`.
- Notable changes:
- The format of the `transforms` keyword for the {func}`jax.experimental.host_callback.id_tap` primitive has changed {jax-issue}`#3132`.
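
The single-operand form of `lax.cond` mentioned in this release might be used as sketched below. The helper `safe_sqrt` is a hypothetical example, not a JAX function.

```python
import jax.numpy as jnp
from jax import lax


def safe_sqrt(x):
    # The same operand `x` is passed to whichever branch is selected.
    return lax.cond(x >= 0, jnp.sqrt, lambda v: jnp.zeros_like(v), x)


positive = safe_sqrt(jnp.float32(9.0))
negative = safe_sqrt(jnp.float32(-4.0))
```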
- GitHub commits.
- New features:
- Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`.
- Experimental support for printing and calling host-side Python function from compiled code. See id_print and id_tap ({jax-issue}`#3006`).
- Notable changes:
- The visibility of names exported from {mod}`jax.numpy` has been tightened. This may break code that was making use of names that were previously exported accidentally.
- Fixes crash for outfeed.
- GitHub commits.
- New features:
- Support for `in_axes=None` on {func}`pmap` {jax-issue}`#2896`.
- Fixes crash for linear algebra functions on Mac OS X (#432).
- Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
- GitHub commits.
- New features:
- Differentiation of determinants of singular matrices {jax-issue}`#2809`.
- Bug fixes:
- Fix {func}`odeint` differentiation with respect to time of ODEs with time-dependent dynamics {jax-issue}`#2817`, also add ODE CI testing.
- Fix {func}`lax_linalg.qr` differentiation {jax-issue}`#2867`.
- Fixes segfault: {jax-issue}`#2755`
- Plumb is_stable option on Sort HLO through to Python.
- GitHub commits.
- New features:
- Add syntactic sugar for functional indexed updates {jax-issue}`#2684`.
- Add {func}`jax.numpy.linalg.multi_dot` {jax-issue}`#2726`.
- Add {func}`jax.numpy.unique` {jax-issue}`#2760`.
- Add {func}`jax.numpy.rint` {jax-issue}`#2724`.
- Add more primitive rules for {func}`jax.experimental.jet`.
- Bug fixes:
- Fix {func}`logaddexp` and {func}`logaddexp2` differentiation at zero {jax-issue}`#2107`.
- Improve memory usage in reverse-mode autodiff without {func}`jit` {jax-issue}`#2719`.
- Better errors:
- Improves error message for reverse-mode differentiation of {func}`lax.while_loop` {jax-issue}`#2129`.
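
The functional indexed-update sugar added in this release can be sketched as below, assuming it refers to the `.at[...]` property as it exists in current JAX. Each update returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Out-of-place equivalents of `x[1] = 10.0` and `y[2:4] += 1.0`.
y = x.at[1].set(10.0)
z = y.at[2:4].add(1.0)
```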
- Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
- Bugfix for `batch_group_count` convolutions.
- Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
- GitHub commits.
- Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the tutorial notebook. Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
- Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`.
- Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`.
- Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`.
- Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`.
- Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
- Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`.
- Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`.
- Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`.
- Add docstring for `test_util.check_grads` {jax-issue}`#2656`.
- Add `callback_transform` {jax-issue}`#2665`.
- Implement `rollaxis`, `convolve`/`correlate` 1d & 2d, `copysign`, `trunc`, `roots`, and `quantile`/`percentile` interpolation options.
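
As a brief sketch of what `jax.custom_jvp` from this release enables, the example below attaches a numerically stable derivative rule to a function whose automatically derived gradient would overflow for large inputs. The function `log1pexp` is a common illustration, not something the release itself ships.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # Derivative written as 1 - 1 / (1 + e^x), which stays finite
    # even when e^x overflows to inf.
    return ans, (1 - 1 / (1 + jnp.exp(x))) * x_dot


grad_large = jax.grad(log1pexp)(100.0)
grad_zero = jax.grad(log1pexp)(0.0)
```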
- Fixed a performance regression for Resnet-50 on GPU.
- GitHub commits.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- Removed the internal function `lax._safe_mul`, which implemented the convention `0. * nan == 0.`. This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
- Added an `all_gather` parallel convenience function.
- More type annotations in core code.
- jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- GitHub commits.
- Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
- GitHub commits.
- New features:
- {py:func}`jax.pmap` has a `static_broadcast_argnums` argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcasted to all devices. It works analogously to `static_argnums` in {py:func}`jax.jit`.
- Improved error messages for when tracers are mistakenly saved in global state.
- Added {py:func}`jax.nn.one_hot` utility function.
- Added {mod}`jax.experimental.jet` for exponentially faster higher-order automatic differentiation.
- Added more correctness checking to arguments of {py:func}`jax.lax.broadcast_in_dim`.
- The minimum jaxlib version is now 0.1.41.
- Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
- Includes prototype support for multihost GPU computations that communicate via NCCL.
- Improves performance of NCCL collectives on GPU.
- Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
- Supports device assignments known at XLA compilation time.
- Breaking changes
- The minimum jaxlib version is now 0.1.38.
- Simplified {py:class}`Jaxpr` by removing the `Jaxpr.freevars` and `Jaxpr.bound_subjaxprs`. The call primitives (`xla_call`, `xla_pmap`, `sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a fully-closed (no `constvars`) jaxpr. Also, added a new field `call_primitive` to primitives.
- New features:
- Reverse-mode automatic differentiation (e.g. `grad`) of `lax.cond`, making it now differentiable in both modes ({jax-issue}`#2091`)
- JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
- JAX GPU DeviceArrays now support `__cuda_array_interface__`, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
- JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
- Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known as slow.
- Updates XLA.
- CUDA 9.0 is no longer supported.
- CUDA 10.2 wheels are now built by default.
- Breaking changes
- JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
- New features
- Forward-mode automatic differentiation (`jvp`) of while loop ({jax-issue}`#1980`)
- New NumPy and SciPy functions:
  - {py:func}`jax.numpy.fft.fft2`
  - {py:func}`jax.numpy.fft.ifft2`
  - {py:func}`jax.numpy.fft.rfft`
  - {py:func}`jax.numpy.fft.irfft`
  - {py:func}`jax.numpy.fft.rfft2`
  - {py:func}`jax.numpy.fft.irfft2`
  - {py:func}`jax.numpy.fft.rfftn`
  - {py:func}`jax.numpy.fft.irfftn`
  - {py:func}`jax.numpy.fft.fftfreq`
  - {py:func}`jax.numpy.fft.rfftfreq`
  - {py:func}`jax.numpy.linalg.matrix_rank`
  - {py:func}`jax.numpy.linalg.matrix_power`
  - {py:func}`jax.scipy.special.betainc`
- Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
- With the Python 3 upgrade, JAX no longer depends on `fastcache`, which should help with installation.
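
Forward-mode differentiation of a while loop, as listed under the new features above, can be sketched as follows. The function `cube` is a made-up example that computes `x**3` with three loop iterations.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cube(x):
    def cond_fun(state):
        i, _ = state
        return i < 3

    def body_fun(state):
        i, acc = state
        return i + 1, acc * x

    # Start from 1 and multiply by x three times.
    _, result = lax.while_loop(cond_fun, body_fun, (0, jnp.ones_like(x)))
    return result


# jvp returns the value and the directional derivative in one pass.
primal, tangent = jax.jvp(cube, (jnp.float32(2.0),), (jnp.float32(1.0),))
```

For `x = 2` the primal is `x**3` and the tangent is `3 * x**2`.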