
JAX release v0.3.7

@hawkinsp released this on 29 Apr at 18:09
  • Fixed a performance problem when the indices passed to jax.numpy.take_along_axis were broadcast (#10281).
  • jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
  • The DeviceArray.tile() method is deprecated because NumPy arrays do not have a tile() method. Use jax.numpy.tile instead (#10266).
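As a rough illustration of the take_along_axis fix, the sketch below passes a broadcast index array to jax.numpy.take_along_axis. The release note does not say how the indices in #10281 were produced; building them with jnp.broadcast_to is an assumption made here, and the snippet shows only the call pattern, not the performance change itself.

```python
import jax.numpy as jnp

# A 3x4 integer array: rows are [0..3], [4..7], [8..11].
x = jnp.arange(12).reshape(3, 4)

# Column indices [0, 2], broadcast to one row of indices per row of x.
# (Using jnp.broadcast_to is an assumption about how the indices were
# broadcast; see #10281 for the original report.)
idx = jnp.broadcast_to(jnp.array([[0, 2]]), (3, 2))

# Gather along axis 1: picks columns 0 and 2 from every row.
out = jnp.take_along_axis(x, idx, axis=1)
print(out.tolist())  # [[0, 2], [4, 6], [8, 10]]
```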
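A minimal sketch of adapting code to the expit/logit change, assuming input that previously arrived as a Python list: convert it with jnp.asarray first, and expect integer input to come back as floating point. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

# These functions expect scalars or JAX arrays, so convert list-like
# input explicitly instead of passing e.g. [0, 1, 2] directly.
ints = jnp.asarray([0, 1, 2])

# Integer arguments are promoted to floating point before evaluation.
probs = expit(ints)
print(probs.dtype)  # a floating-point dtype (float32 unless x64 is enabled)

# Scalars are still accepted; logit is the inverse of expit.
print(float(expit(0)))    # 0.5
print(float(logit(0.5)))  # 0.0
```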
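For code that still calls the deprecated method, the migration is mechanical: pass the array to jax.numpy.tile as its first argument, with the same repetition argument. A minimal sketch:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Before (deprecated in v0.3.7): x.tile(2)
# After: the function form, which mirrors numpy.tile.
flat = jnp.tile(x, 2)       # repeats along the last axis
grid = jnp.tile(x, (2, 1))  # stacks two copies as rows

print(flat.tolist())  # [1, 2, 3, 1, 2, 3]
print(grid.shape)     # (2, 3)
```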