Skip to content
This repository has been archived by the owner on Oct 26, 2024. It is now read-only.

Update dependency jax to v0.4.13 #334

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Update dependency jax to v0.4.13 #334

wants to merge 1 commit into from

Conversation

renovate[bot]
Copy link

@renovate renovate bot commented Apr 24, 2023

Mend Renovate

This PR contains the following updates:

Package Change Age Adoption Passing Confidence
jax 0.4.6 -> 0.4.13 age adoption passing confidence

Release Notes

google/jax

v0.4.13

Compare Source

  • Changes

    • jax.jit now allows None to be passed to in_shardings and
      out_shardings. The semantics are as follows:
      • For in_shardings, JAX will mark is as replicated but this behavior
        can change in the future.
      • For out_shardings, we will rely on the XLA GSPMD partitioner to
        determine the output shardings.
    • jax.experimental.pjit.pjit also allows None to be passed to
      in_shardings and out_shardings. The semantics are as follows:
      • If the mesh context manager is not provided, JAX has the freedom to
        choose whatever sharding it wants.
        • For in_shardings, JAX will mark is as replicated but this behavior
          can change in the future.
        • For out_shardings, we will rely on the XLA GSPMD partitioner to
          determine the output shardings.
      • If the mesh context manager is provided, None will imply that the value
        will be replicated on all devices of the mesh.
    • Executable.cost_analysis() works on Cloud TPU
    • Added a warning if a non-allowlisted jaxlib plugin is in use.
    • Added jax.tree_util.tree_leaves_with_path.
  • Bug fixes

    • Fixed incorrect wheel name in CUDA 12 releases (#​16362); the correct wheel
      is named cudnn89 instead of cudnn88.
  • Deprecations

    • The native_serialization_strict_checks parameter to
      {func}jax.experimental.jax2tf.convert is deprecated in favor of the
      new native_serializaation_disabled_checks ({jax-issue}#16347).

v0.4.12

Compare Source

  • Changes

    • Added {class}scipy.spatial.transform.Rotation and {class}scipy.spatial.transform.Slerp
  • Deprecations

    • jax.abstract_arrays and its contents are now deprecated. See related
      functionality in :mod:jax.core.
    • jax.numpy.alltrue: use jax.numpy.all. This follows the deprecation
      of numpy.alltrue in NumPy version 1.25.0.
    • jax.numpy.sometrue: use jax.numpy.any. This follows the deprecation
      of numpy.sometrue in NumPy version 1.25.0.
    • jax.numpy.product: use jax.numpy.prod. This follows the deprecation
      of numpy.product in NumPy version 1.25.0.
    • jax.numpy.cumproduct: use jax.numpy.cumprod. This follows the deprecation
      of numpy.cumproduct in NumPy version 1.25.0.
    • jax.sharding.OpShardingSharding has been removed since it has been 3
      months since it was deprecated.

v0.4.11

Compare Source

  • Deprecations
    • The following APIs have been removed after a 3 month deprecation period, in
      accordance with the {ref}api-compatibility policy:
      • jax.experimental.PartitionSpec: use jax.sharding.PartitionSpec.
      • jax.experimental.maps.Mesh: use jax.sharding.Mesh
      • jax.experimental.pjit.NamedSharding: use jax.sharding.NamedSharding.
      • jax.experimental.pjit.PartitionSpec: use jax.sharding.PartitionSpec.
      • jax.experimental.pjit.FROM_GDA. Instead pass sharded jax.Array objects
        as input and remove the optional in_shardings argument to pjit.
      • jax.interpreters.pxla.PartitionSpec: use jax.sharding.PartitionSpec.
      • jax.interpreters.pxla.Mesh: use jax.sharding.Mesh
      • jax.interpreters.xla.Buffer: use jax.Array.
      • jax.interpreters.xla.Device: use jax.Device.
      • jax.interpreters.xla.DeviceArray: use jax.Array.
      • jax.interpreters.xla.device_put: use jax.device_put.
      • jax.interpreters.xla.xla_call_p: use jax.experimental.pjit.pjit_p.
      • axis_resources argument of with_sharding_constraint is removed. Please
        use shardings instead.

v0.4.10

Compare Source

v0.4.9

Compare Source

  • Changes

    • The flags experimental_cpp_jit, experimental_cpp_pjit and
      experimental_cpp_pmap have been removed.
      They are now always on.
    • Accuracy of singular value decomposition (SVD) on TPU has been improved
      (requires jaxlib 0.4.9).
  • Deprecations

    • jax.experimental.gda_serialization is deprecated and has been renamed to
      jax.experimental.array_serialization.
      Please change your imports to use jax.experimental.array_serialization.
    • The in_axis_resources and out_axis_resources arguments of pjit have been
      deprecated. Please use in_shardings and out_shardings respectively.
    • The function jax.numpy.msort has been removed. It has been deprecated since
      JAX v0.4.1. Use jnp.sort(a, axis=0) instead.
    • in_parts and out_parts arguments have been removed from jax.xla_computation
      since they were only used with sharded_jit and sharded_jit is long gone.
    • instantiate_const_outputs argument has been removed from jax.xla_computation
      since it has been unused for a very long time.

v0.4.8

Compare Source

  • Breaking changes

    • A major component of the Cloud TPU runtime has been upgraded. This enables
      the following new features on Cloud TPU:

      • {func}jax.debug.print, {func}jax.debug.callback, and
        {func}jax.debug.breakpoint() now work on Cloud TPU
      • Automatic TPU memory defragmentation

      {func}jax.experimental.host_callback is no longer supported on Cloud TPU
      with the new runtime component. Please file an issue on the JAX issue
      tracker
      if the new jax.debug APIs
      are insufficient for your use case.

      The old runtime component will be available for at least the next three
      months by setting the environment variable
      JAX_USE_PJRT_C_API_ON_TPU=false. If you find you need to disable the new
      runtime for any reason, please let us know on the JAX issue
      tracker
      .

  • Changes

    • The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
  • Deprecations

    • CUDA 11.4 support has been dropped. JAX GPU wheels only support
      CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
      from source.
    • global_arg_shapes argument of pmap only worked with sharded_jit and has
      been removed from pmap. Please migrate to pjit and remove global_arg_shapes
      from pmap.

v0.4.7

Compare Source

  • Changes

    • As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration
      jax.config.jax_array cannot be disabled anymore.
    • jax.config.jax_jit_pjit_api_merge cannot be disabled anymore.
    • {func}jax.experimental.jax2tf.convert now supports the native_serialization
      parameter to use JAX's native lowering to StableHLO to obtain a
      StableHLO module for the entire JAX function instead of lowering each JAX
      primitive to a TensorFlow op. This simplifies the internals and increases
      the confidence that what you serialize matches the JAX native semantics.
      See documentation.
      As part of this change the config flag --jax2tf_default_experimental_native_lowering
      has been renamed to --jax2tf_native_serialization.
    • JAX now depends on ml_dtypes, which contains definitions of NumPy types
      like bfloat16. These definitions were previously internal to JAX, but have
      been split into a separate package to facilitate sharing them with other
      projects.
    • JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
  • Deprecations

    • The type jax.numpy.DeviceArray is deprecated. Use jax.Array instead,
      for which it is an alias.
    • The type jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use
      jax.Array instead.
    • Passing additional arguments to {func}jax.numpy.ndarray.at by position is deprecated.
      For example, instead of x.at[i].get(True), use x.at[i].get(indices_are_sorted=True)
    • jax.interpreters.xla.device_put is deprecated. Please use jax.device_put.
    • jax.interpreters.pxla.device_put is deprecated. Please use jax.device_put.
    • jax.experimental.pjit.FROM_GDA is deprecated. Please pass in sharded
      jax.Arrays as input and remove the in_shardings argument to pjit since
      it is optional.

Configuration

📅 Schedule: Branch creation - "before 4am on Monday" (UTC), Automerge - At any time (no schedule defined).

🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.

Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.

🔕 Ignore: Close this PR and you won't be reminded about this update again.


  • If you want to rebase/retry this PR, check this box

This PR has been generated by Mend Renovate. View repository job log here.

@renovate renovate bot changed the title Update dependency jax to v0.4.8 Update dependency jax to v0.4.9 May 10, 2023
@renovate renovate bot changed the title Update dependency jax to v0.4.9 Update dependency jax to v0.4.10 May 12, 2023
@renovate renovate bot changed the title Update dependency jax to v0.4.10 Update dependency jax to v0.4.11 Jun 1, 2023
@renovate renovate bot changed the title Update dependency jax to v0.4.11 Update dependency jax to v0.4.12 Jun 9, 2023
@renovate renovate bot changed the title Update dependency jax to v0.4.12 Update dependency jax to v0.4.13 Jun 23, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

0 participants