
Releases: jax-ml/jax

JAX release v0.3.7

29 Apr 18:09
  • Fixed a performance problem when the indices passed to jax.numpy.take_along_axis were broadcast (#10281).
  • jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
  • The DeviceArray.tile() method is deprecated, because NumPy arrays do not have a tile() method. Use jax.numpy.tile instead (#10266).
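A minimal sketch of the three items above, assuming a recent JAX install; the arrays and variable names are illustrative only. It passes take_along_axis an index array whose leading dimension of 1 broadcasts against the input, feeds an integer array to expit, and calls jax.numpy.tile where DeviceArray.tile() would previously have been used.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

x = jnp.arange(6.0).reshape(2, 3)

# Index array of shape (1, 3): its leading axis broadcasts against the
# two rows of x, so both rows are gathered with the same column order.
idx = jnp.array([[2, 0, 1]])
picked = jnp.take_along_axis(x, idx, axis=1)  # [[2., 0., 1.], [5., 3., 4.]]

# expit/logit take JAX arrays; integer inputs are promoted to floating point.
p = expit(jnp.array([0, 1, 2]))
back = logit(p)  # approximately [0., 1., 2.]

# jnp.tile replaces the deprecated DeviceArray.tile() method.
tiled = jnp.tile(jnp.array([1, 2]), 3)  # [1, 2, 1, 2, 1, 2]
```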

JAX release v0.3.6

13 Apr 00:52
  • Changes:
    • Upgraded the libtpu wheel to a fixed version. Fixes #10218.

JAX release v0.3.5

07 Apr 20:29

Changes

  • Added jax.random.loggamma and improved the behavior of jax.random.beta
    and jax.random.dirichlet for small parameter values (#9906).
  • The private lax_numpy submodule is no longer exposed in the jax.numpy namespace (#10029).
  • Added array creation routines jax.numpy.frombuffer, jax.numpy.fromfunction,
    and jax.numpy.fromstring (#10049).
  • DeviceArray.copy() now returns a DeviceArray rather than a np.ndarray (#10069).
  • Added jax.scipy.linalg.rsf2csf.
  • Deprecations:
    • jax.nn.normalize is being deprecated. Use jax.nn.standardize instead (#9899).
    • jax.tree_util.tree_multimap is deprecated. Use jax.tree_util.tree_map instead (#5746).
    • jax.experimental.sharded_jit is deprecated. Use pjit instead.
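The replacements named in the deprecations above, plus two of the new routines, can be sketched as follows. This assumes a recent JAX release; params, grads, and the lambda passed to fromfunction are hypothetical examples.

```python
import jax
import jax.numpy as jnp

# tree_map accepts several trees, which covers what tree_multimap did.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(0.05)}
updated = jax.tree_util.tree_map(lambda p, g: p - g, params, grads)

# jax.nn.standardize is the replacement for jax.nn.normalize:
# it rescales to zero mean and unit variance along the last axis.
z = jax.nn.standardize(jnp.array([1.0, 2.0, 3.0]))

# New array-creation routine: entry (i, j) holds f(i, j).
grid = jnp.fromfunction(lambda i, j: i + j, (2, 3))

# New sampler returning the log of Gamma(a) draws, intended for small a
# where the Gamma samples themselves may be too close to zero to represent.
key = jax.random.PRNGKey(0)
log_samples = jax.random.loggamma(key, 0.01, shape=(4,))
```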

JAX release v0.3.4

18 Mar 21:13

Fix a bug introduced in #9923.

JAX release v0.3.3

17 Mar 22:31

Jax release v0.3.1

18 Feb 22:36
  • Changes:
    • jax.test_util.JaxTestCase and jax.test_util.JaxTestLoader are now deprecated.
      The suggested replacement is to use absl.testing.parameterized.TestCase directly. For tests that
      rely on custom asserts such as JaxTestCase.assertAllClose(), the suggested replacement
      is to use standard NumPy testing utilities such as numpy.testing.assert_allclose(),
      which work directly with JAX arrays (#9620).
    • jax.test_util.JaxTestCase now sets jax_numpy_rank_promotion='raise' by default
      (#9562). To recover the previous behavior, use the new
      jax.test_util.with_config decorator (imported below as jtu):
      @jtu.with_config(jax_numpy_rank_promotion='allow')
      class MyTestCase(jtu.JaxTestCase):
        ...
    • Added jax.scipy.linalg.schur, jax.scipy.linalg.sqrtm,
      jax.scipy.signal.csd, jax.scipy.signal.stft,
      jax.scipy.signal.welch.
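A sketch of a test written without JaxTestCase, under two stated assumptions: a plain unittest.TestCase stands in for absl's parameterized.TestCase, and the rank-promotion mode is toggled through jax.config.update rather than the with_config decorator. The class name and matrix are hypothetical, and the sketch assumes the CPU backend, since JAX's sqrtm may only be available there.

```python
import unittest

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import sqrtm


class SqrtmTest(unittest.TestCase):
    # Plain unittest.TestCase stands in for absl's parameterized.TestCase.

    def test_sqrtm_squares_back(self):
        a = jnp.array([[4.0, 0.0], [0.0, 9.0]])
        root = sqrtm(a)  # may come back with a complex dtype
        # numpy.testing accepts JAX arrays directly, replacing assertAllClose.
        np.testing.assert_allclose(root @ root, a, rtol=1e-5, atol=1e-5)

    def test_rank_promotion_raise(self):
        # Emulate JaxTestCase's new default, then restore JAX's default "allow".
        jax.config.update("jax_numpy_rank_promotion", "raise")
        try:
            with self.assertRaises(ValueError):
                jnp.ones(3) + jnp.ones((2, 3))  # implicit rank promotion
        finally:
            jax.config.update("jax_numpy_rank_promotion", "allow")
```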

Jaxlib release v0.3.0

10 Feb 20:07
  • Changes
    • Bazel 5.0.0 is now required to build jaxlib.
    • jaxlib version has been bumped to 0.3.0. Please see the design doc
      for the explanation.

Jax release v0.3.0

10 Feb 20:07
  • Changes
    • jax version has been bumped to 0.3.0. Please see the design doc
      for the explanation.

JAX release v0.2.28

02 Feb 01:55
  • Changes
    • jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no
      dialect= is passed.
    • jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR
      ir.Module object instead of its string representation.
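A sketch of inspecting the lowered IR. One assumption to flag: in the release described above the default dialect was MHLO, whereas recent JAX releases default to StableHLO, so this example passes no dialect= argument and relies only on the distinction between the module object and its text. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * 2.0


lowered = jax.jit(f).lower(jnp.ones(3))

# compiler_ir() returns an MLIR module object, not a string.
module = lowered.compiler_ir()

# When the textual form is wanted, ask for it explicitly.
text = lowered.as_text()
```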

Jaxlib v0.1.76

28 Jan 15:22
  • New features
    • Includes precompiled SASS for NVIDIA compute capability 8.0 GPUs
      (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not
      to increase the number of compute capabilities: GPUs with compute capability
      6.1 can use the 6.0 SASS.
    • With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR
      by default.
  • Breaking changes
    • Support for NumPy 1.18 has been dropped, per the deprecation policy.
      Please upgrade to a supported NumPy version.
  • Bug fixes
    • Fixed a bug where apparently identical pytreedef objects constructed by different routes
      did not compare as equal (#9066).
    • The JAX jit cache now requires two values of a static argument to have identical types,
      not just equal values, for a cache hit (#9311).
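Both bug fixes can be observed with a short sketch, assuming a recent JAX release. The treedef check compares a structure obtained by flattening with one recovered from a rebuilt tree. The cache check counts traces through a Python side effect (the increment runs only while JAX traces the function), so calling with the static value 2 and then 2.0, which are equal but of different types, should trace twice. The names scale and trace_count are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp

# Treedefs reached by different routes compare equal.
leaves, treedef = jax.tree_util.tree_flatten((1, {"a": 2}))
rebuilt = jax.tree_util.tree_unflatten(treedef, [10, 20])
same_structure = jax.tree_util.tree_structure(rebuilt) == treedef

trace_count = 0


@functools.partial(jax.jit, static_argnums=1)
def scale(x, k):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return x * k


x = jnp.ones(2)
scale(x, 2)    # first call: traced and compiled
scale(x, 2)    # same static value and type: cache hit, no retrace
scale(x, 2.0)  # 2 == 2.0, but int vs float: separate cache entry
```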