Skip to content

Commit

Permalink
Merge pull request #25214 from jakevdp:einsum-optimize
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712603103
  • Loading branch information
Google-ML-Automation committed Jan 6, 2025
2 parents 9f84290 + 2f7204f commit 52cc5c7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
the case of many arguments ({jax-issue}`#25214`).

* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
Expand Down
10 changes: 5 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9503,7 +9503,7 @@ def einsum(
subscript: str, /,
*operands: ArrayLike,
out: None = None,
optimize: str | bool | list[tuple[int, ...]] = "optimal",
optimize: str | bool | list[tuple[int, ...]] = "auto",
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
_dot_general: Callable[..., Array] = lax.dot_general,
Expand All @@ -9516,7 +9516,7 @@ def einsum(
axes: Sequence[Any], /,
*operands: ArrayLike | Sequence[Any],
out: None = None,
optimize: str | bool | list[tuple[int, ...]] = "optimal",
optimize: str | bool | list[tuple[int, ...]] = "auto",
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
_dot_general: Callable[..., Array] = lax.dot_general,
Expand All @@ -9528,7 +9528,7 @@ def einsum(
subscripts, /,
*operands,
out: None = None,
optimize: str | bool | list[tuple[int, ...]] = "optimal",
optimize: str | bool | list[tuple[int, ...]] = "auto",
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
_dot_general: Callable[..., Array] = lax.dot_general,
Expand All @@ -9548,10 +9548,10 @@ def einsum(
subscripts: string containing axes names separated by commas.
*operands: sequence of one or more arrays corresponding to the subscripts.
optimize: specify how to optimize the order of computation. In JAX this defaults
to ``"optimal"`` which produces optimized expressions via the opt_einsum_
to ``"auto"`` which produces optimized expressions via the opt_einsum_
package. Other options are ``True`` (same as ``"optimal"``), ``False``
(unoptimized), or any string supported by ``opt_einsum``, which
includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also
includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others. It may also
be a pre-computed path (see :func:`~jax.numpy.einsum_path`).
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
Expand Down

0 comments on commit 52cc5c7

Please sign in to comment.