This repository has been archived by the owner on Oct 26, 2024. It is now read-only.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR contains the following updates:
0.4.6
->0.4.13
Release Notes
google/jax
v0.4.13
Compare Source
Changes
jax.jit
now allowsNone
to be passed toin_shardings
andout_shardings
. The semantics are as follows:can change in the future.
determine the output shardings.
jax.experimental.pjit.pjit
also allowsNone
to be passed toin_shardings
andout_shardings
. The semantics are as follows:choose whatever sharding it wants.
can change in the future.
determine the output shardings.
will be replicated on all devices of the mesh.
jaxlib
plugin is in use.jax.tree_util.tree_leaves_with_path
.Bug fixes
is named
cudnn89
instead ofcudnn88
.Deprecations
native_serialization_strict_checks
parameter to{func}
jax.experimental.jax2tf.convert
is deprecated in favor of thenew
native_serializaation_disabled_checks
({jax-issue}#16347
).v0.4.12
Compare Source
Changes
scipy.spatial.transform.Rotation
and {class}scipy.spatial.transform.Slerp
Deprecations
jax.abstract_arrays
and its contents are now deprecated. See relatedfunctionality in :mod:
jax.core
.jax.numpy.alltrue
: usejax.numpy.all
. This follows the deprecationof
numpy.alltrue
in NumPy version 1.25.0.jax.numpy.sometrue
: usejax.numpy.any
. This follows the deprecationof
numpy.sometrue
in NumPy version 1.25.0.jax.numpy.product
: usejax.numpy.prod
. This follows the deprecationof
numpy.product
in NumPy version 1.25.0.jax.numpy.cumproduct
: usejax.numpy.cumprod
. This follows the deprecationof
numpy.cumproduct
in NumPy version 1.25.0.jax.sharding.OpShardingSharding
has been removed since it has been 3months since it was deprecated.
v0.4.11
Compare Source
accordance with the {ref}
api-compatibility
policy:jax.experimental.PartitionSpec
: usejax.sharding.PartitionSpec
.jax.experimental.maps.Mesh
: usejax.sharding.Mesh
jax.experimental.pjit.NamedSharding
: usejax.sharding.NamedSharding
.jax.experimental.pjit.PartitionSpec
: usejax.sharding.PartitionSpec
.jax.experimental.pjit.FROM_GDA
. Instead pass shardedjax.Array
objectsas input and remove the optional
in_shardings
argument topjit
.jax.interpreters.pxla.PartitionSpec
: usejax.sharding.PartitionSpec
.jax.interpreters.pxla.Mesh
: usejax.sharding.Mesh
jax.interpreters.xla.Buffer
: usejax.Array
.jax.interpreters.xla.Device
: usejax.Device
.jax.interpreters.xla.DeviceArray
: usejax.Array
.jax.interpreters.xla.device_put
: usejax.device_put
.jax.interpreters.xla.xla_call_p
: usejax.experimental.pjit.pjit_p
.axis_resources
argument ofwith_sharding_constraint
is removed. Pleaseuse
shardings
instead.v0.4.10
Compare Source
v0.4.9
Compare Source
Changes
experimental_cpp_pmap have been removed.
They are now always on.
(requires jaxlib 0.4.9).
Deprecations
jax.experimental.gda_serialization
is deprecated and has been renamed tojax.experimental.array_serialization
.Please change your imports to use
jax.experimental.array_serialization
.in_axis_resources
andout_axis_resources
arguments of pjit have beendeprecated. Please use
in_shardings
andout_shardings
respectively.jax.numpy.msort
has been removed. It has been deprecated sinceJAX v0.4.1. Use
jnp.sort(a, axis=0)
instead.in_parts
andout_parts
arguments have been removed fromjax.xla_computation
since they were only used with sharded_jit and sharded_jit is long gone.
instantiate_const_outputs
argument has been removed fromjax.xla_computation
since it has been unused for a very long time.
v0.4.8
Compare Source
Breaking changes
A major component of the Cloud TPU runtime has been upgraded. This enables
the following new features on Cloud TPU:
jax.debug.print
, {func}jax.debug.callback
, and{func}
jax.debug.breakpoint()
now work on Cloud TPU{func}
jax.experimental.host_callback
is no longer supported on Cloud TPUwith the new runtime component. Please file an issue on the JAX issue
tracker if the new
jax.debug
APIsare insufficient for your use case.
The old runtime component will be available for at least the next three
months by setting the environment variable
JAX_USE_PJRT_C_API_ON_TPU=false
. If you find you need to disable the newruntime for any reason, please let us know on the JAX issue
tracker.
Changes
Deprecations
CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
from source.
global_arg_shapes
argument of pmap only worked with sharded_jit and hasbeen removed from pmap. Please migrate to pjit and remove global_arg_shapes
from pmap.
v0.4.7
Compare Source
Changes
jax.config.jax_array
cannot be disabled anymore.jax.config.jax_jit_pjit_api_merge
cannot be disabled anymore.jax.experimental.jax2tf.convert
now supports thenative_serialization
parameter to use JAX's native lowering to StableHLO to obtain a
StableHLO module for the entire JAX function instead of lowering each JAX
primitive to a TensorFlow op. This simplifies the internals and increases
the confidence that what you serialize matches the JAX native semantics.
See documentation.
As part of this change the config flag
--jax2tf_default_experimental_native_lowering
has been renamed to
--jax2tf_native_serialization
.ml_dtypes
, which contains definitions of NumPy typeslike bfloat16. These definitions were previously internal to JAX, but have
been split into a separate package to facilitate sharing them with other
projects.
Deprecations
jax.numpy.DeviceArray
is deprecated. Usejax.Array
instead,for which it is an alias.
jax.interpreters.pxla.ShardedDeviceArray
is deprecated. Usejax.Array
instead.jax.numpy.ndarray.at
by position is deprecated.For example, instead of
x.at[i].get(True)
, usex.at[i].get(indices_are_sorted=True)
jax.interpreters.xla.device_put
is deprecated. Please usejax.device_put
.jax.interpreters.pxla.device_put
is deprecated. Please usejax.device_put
.jax.experimental.pjit.FROM_GDA
is deprecated. Please pass in shardedjax.Arrays as input and remove the
in_shardings
argument to pjit sinceit is optional.
Configuration
📅 Schedule: Branch creation - "before 4am on Monday" (UTC), Automerge - At any time (no schedule defined).
🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.
♻ Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about this update again.
This PR has been generated by Mend Renovate. View repository job log here.