You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Handle PRNG keys in reshard by using jax.device_put.
The `reshard` function now treats `jax.Array` instances that are PRNG keys as non-reshardable via the experimental reshard sidechannel API, directing them through `jax.device_put` instead of the custom device-to-device transfer logic. This ensures correct handling of PRNG keys when resharding PyTrees. A new test case is added to confirm proper resharding of PyTrees containing random keys.
PiperOrigin-RevId: 814746994
0 commit comments