diff --git a/xla/python/pytree.cc b/xla/python/pytree.cc index 5592a96454821..65bfb3fe5305e 100644 --- a/xla/python/pytree.cc +++ b/xla/python/pytree.cc @@ -595,9 +595,15 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { case PyTreeKind::kNone: if (!object.is_none()) { - throw std::invalid_argument( - absl::StrFormat("Expected None, got %s.", - nb::cast(nb::repr(object)))); + PythonDeprecationWarning( + /*stacklevel=*/3, + "In a future release of JAX, flatten-up-to will no longer " + "consider None to be a tree-prefix of non-None values, got: " + "%s.\n\n" + "To preserve the current behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object))); } break; diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index a92abf945c4d4..ede0f7749b9b3 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 284 +_version = 283 # Version number for MLIR:Python components. mlir_api_version = 57