diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ffdafbb1cba..356e16bd4d4d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Changes: * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025. + * {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than + `optimize='optimal'`. This avoids exponentially-scaling trace-time in + the case of many arguments ({jax-issue}`#25214`). * New Features * {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`, diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f232612526f4..2681cbc81283 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9503,7 +9503,7 @@ def einsum( subscript: str, /, *operands: ArrayLike, out: None = None, - optimize: str | bool | list[tuple[int, ...]] = "optimal", + optimize: str | bool | list[tuple[int, ...]] = "auto", precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, @@ -9516,7 +9516,7 @@ def einsum( axes: Sequence[Any], /, *operands: ArrayLike | Sequence[Any], out: None = None, - optimize: str | bool | list[tuple[int, ...]] = "optimal", + optimize: str | bool | list[tuple[int, ...]] = "auto", precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, @@ -9528,7 +9528,7 @@ def einsum( subscripts, /, *operands, out: None = None, - optimize: str | bool | list[tuple[int, ...]] = "optimal", + optimize: str | bool | list[tuple[int, ...]] = "auto", precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, @@ -9548,10 +9548,10 @@ def einsum( subscripts: string containing axes names separated by commas. *operands: sequence of one or more arrays corresponding to the subscripts. optimize: specify how to optimize the order of computation. In JAX this defaults - to ``"optimal"`` which produces optimized expressions via the opt_einsum_ + to ``"auto"`` which produces optimized expressions via the opt_einsum_ package. Other options are ``True`` (same as ``"optimal"``), ``False`` (unoptimized), or any string supported by ``opt_einsum``, which - includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also + includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others. It may also be a pre-computed path (see :func:`~jax.numpy.einsum_path`). precision: either ``None`` (default), which means the default precision for the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,