JAX release v0.3.7
- Fixed a performance problem when the indices passed to jax.numpy.take_along_axis were broadcast (#10281).
- jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
- The DeviceArray.tile() method is deprecated, because NumPy arrays do not have a tile() method. Use jax.numpy.tile instead (#10266).
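For the take_along_axis entry, "broadcast indices" means an index array whose non-axis dimensions have size 1, so the same indices are reused across the whole array. The sketch below shows that case. The array contents and the names arr, idx and out are illustrative only. It assumes NumPy-style broadcasting rules for take_along_axis, and it cross-checks the result against numpy.take_along_axis.

```python
import numpy as np
import jax.numpy as jnp

# A 3x4 array: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]].
arr = jnp.arange(12).reshape(3, 4)

# idx has size 1 along axis 0, so the same pair of column indices
# (3, then 0) is broadcast across every row of arr.
idx = jnp.array([[3, 0]])

# Gather along axis 1; the result has shape (3, 2).
out = jnp.take_along_axis(arr, idx, axis=1)
# out == [[3, 0], [7, 4], [11, 8]]

# NumPy applies the same broadcasting rule, so the results agree.
expected = np.take_along_axis(np.arange(12).reshape(3, 4), np.array([[3, 0]]), axis=1)
assert np.array_equal(np.asarray(out), expected)
```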
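The expit/logit entry describes two behaviours: integer inputs come back as floating point, and non-array inputs such as Python lists are rejected. The sketch below is a minimal illustration assuming a JAX version at or after this release. The variable names are hypothetical, and the exact error message is not relied on, only that a TypeError is raised for a list.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

# Integer inputs are promoted to a floating-point dtype.
x = jnp.array([0, 1, 2])
y = expit(x)

# logit inverts expit, up to floating-point rounding.
roundtrip = logit(y)

# Python scalars are still accepted; expit(0) is 0.5.
half = expit(0)

# A plain Python list is not a scalar or JAX array, so it is rejected.
try:
    expit([0, 1, 2])
    list_rejected = False
except TypeError:
    list_rejected = True
```

Wrapping list data with jnp.asarray before the call is the straightforward way to satisfy the new requirement.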
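For the DeviceArray.tile() deprecation, migrating means replacing the method call with the function form. The sketch below calls only jax.numpy.tile, so it does not depend on whether a given JAX version still has the method. The input values and variable names are illustrative. It also checks the result against numpy.tile, since matching NumPy is the stated reason for the change.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Function form replacing the deprecated method call x.tile(2).
tiled_1d = jnp.tile(x, 2)        # [1, 2, 3, 1, 2, 3]

# A tuple of repetitions stacks x as two rows: shape (2, 3).
tiled_2d = jnp.tile(x, (2, 1))

# jax.numpy.tile mirrors numpy.tile; NumPy arrays have no .tile() method.
assert np.array_equal(np.asarray(tiled_2d), np.tile(np.array([1, 2, 3]), (2, 1)))
```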