Releases: jax-ml/jax
JAX release v0.3.7
- Fixed a performance problem if the indices passed to jax.numpy.take_along_axis were broadcasted (#10281).
- jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
- The DeviceArray.tile() method is deprecated, because numpy arrays do not have a tile() method. As a replacement for this, use jax.numpy.tile (#10266).
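The integer promotion and the tile() replacement described above can be sketched as follows. This is a minimal illustration, assuming a JAX installation at least as new as this release; the variable names are made up for the example.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

# Integer inputs are promoted to floating point before the function is applied.
ints = jnp.arange(3)        # integer dtype
probs = expit(ints)         # floating-point result

# logit is the inverse of expit, so the round trip approximately recovers the inputs.
roundtrip = logit(probs)

# Replacement for the deprecated DeviceArray.tile() method:
# call the module-level function instead of the array method.
x = jnp.array([1, 2, 3])
tiled = jnp.tile(x, 2)      # repeats x twice along its only axis
```

Passing a plain Python list to expit or logit is what the new argument check is meant to reject; wrap such inputs with jnp.asarray first.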
JAX release v0.3.6
- Changes:
- Upgraded libtpu wheel to the fixed version. Fixes #10218.
JAX release v0.3.5
Changes
- Added jax.random.loggamma and improved behavior of jax.random.beta and jax.random.dirichlet for small parameter values (#9906).
- The private lax_numpy submodule is no longer exposed in the jax.numpy namespace (#10029).
- Added array creation routines jax.numpy.frombuffer, jax.numpy.fromfunction, and jax.numpy.fromstring (#10049).
- DeviceArray.copy() now returns a DeviceArray rather than a np.ndarray (#10069).
- Added jax.scipy.linalg.rsf2csf.
- Deprecations:
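As a rough sketch of the additions in this release, the snippet below calls the new sampler and the three array creation routines, then checks what copy() returns. The parameter values and variable names are illustrative assumptions, not taken from the release notes.

```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)

# loggamma samples in log space, which is intended to behave better than
# taking the log of a gamma sample when the shape parameter is small.
log_samples = jax.random.loggamma(key, 0.01, shape=(4,))

# Array creation routines mirroring their NumPy counterparts.
from_buf = jnp.frombuffer(b"\x01\x02\x03", dtype=np.uint8)
from_fn = jnp.fromfunction(lambda i, j: i + j, (2, 3), dtype=np.int32)
from_str = jnp.fromstring("1 2 3", dtype=np.int32, sep=" ")

# copy() keeps the result as a JAX array instead of converting to NumPy.
copied = from_buf.copy()
```

Use np.asarray(copied) when a NumPy array is explicitly wanted.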
JAX release v0.3.4
Fix a bug introduced in #9923.
JAX release v0.3.3
Jax release v0.3.1
- Changes:
- jax.test_util.JaxTestCase and jax.test_util.JaxTestLoader are now deprecated. The suggested replacement is to use parameterized.TestCase directly. For tests that rely on custom asserts such as JaxTestCase.assertAllClose(), the suggested replacement is to use standard numpy testing utilities such as numpy.testing.assert_allclose(), which work directly with JAX arrays (#9620).
- jax.test_util.JaxTestCase now sets jax_numpy_rank_promotion='raise' by default (#9562). To recover the previous behavior, use the new jax.test_util.with_config decorator:
    @jtu.with_config(jax_numpy_rank_promotion='allow')
    class MyTestCase(jtu.JaxTestCase): ...
- Added jax.scipy.linalg.schur, jax.scipy.linalg.sqrtm, jax.scipy.signal.csd, jax.scipy.signal.stft, jax.scipy.signal.welch.
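The two testing-related changes above can be illustrated without the deprecated test classes. The sketch below assumes that the top-level jax.numpy_rank_promotion context manager is available in the installed version; it is used here in place of the with_config decorator purely for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy's testing helpers accept JAX arrays directly.
x = jnp.linspace(0.0, 1.0, 5)
np.testing.assert_allclose(jnp.sin(x) ** 2 + jnp.cos(x) ** 2, np.ones(5), rtol=1e-6)

# With rank promotion set to 'raise', implicitly mixing ranks is an error.
rank_error = None
with jax.numpy_rank_promotion("raise"):
    try:
        jnp.ones((2, 3)) + jnp.ones(3)  # shapes (2, 3) and (3,) differ in rank
    except ValueError as err:
        rank_error = err

# With 'allow', the same expression broadcasts silently.
with jax.numpy_rank_promotion("allow"):
    allowed = jnp.ones((2, 3)) + jnp.ones(3)
```

Under 'raise', the usual fix is to make the broadcast explicit, for example by indexing with jnp.ones(3)[None, :].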
Jaxlib release v0.3.0
- Changes
- Bazel 5.0.0 is now required to build jaxlib.
- jaxlib version has been bumped to 0.3.0. Please see the design doc
for the explanation.
Jax release v0.3.0
- Changes
- jax version has been bumped to 0.3.0. Please see the design doc
for the explanation.
JAX release v0.2.28
- GitHub commits.
- jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= is passed.
- jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.
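A minimal sketch of inspecting the lowered IR is shown below. The function f is a made-up example, and which dialect compiler_ir() emits by default depends on the installed JAX version, so the snippet does not pass dialect= explicitly.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function used only to have something to lower.
    return jnp.sin(x) * 2.0

# Lowering requires example arguments so shapes and dtypes are known.
lowered = jax.jit(f).lower(jnp.ones(3))

# With no dialect= argument the default dialect is used; the result is an
# MLIR module object rather than a string.
module = lowered.compiler_ir()

# Convert explicitly when the textual form is needed.
text = str(module)
```

Code that previously treated the return value as a string needs the explicit str() conversion after this change.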
Jaxlib v0.1.76
- New features
- Includes precompiled SASS for NVIDIA compute capability 8.0 GPUs (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
- With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
- Breaking changes
- Support for NumPy 1.18 has been dropped, per the deprecation policy.
Please upgrade to a supported NumPy version.
- Bug fixes