Skip to content

Commit c9fb204

Browse files
lukebaumann and copybara-github
authored and committed
Handle PRNG keys in reshard by using jax.device_put.
The `reshard` function now treats `jax.Array` instances that are PRNG keys as not reshardable through the experimental reshard sidechannel API, and routes them through `jax.device_put` instead of the custom device-to-device transfer logic. This ensures correct handling of PRNG keys when resharding PyTrees. A new test case is added to confirm that PyTrees containing random keys are resharded properly.

PiperOrigin-RevId: 814746994
1 parent b72729b commit c9fb204

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

pathwaysutils/experimental/reshard.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,24 +150,27 @@ def reshard(
150150
# put them back together in the right order.
151151
array_info_lambda = lambda: {"arrays": [], "indices": [], "dst_shardings": []}
152152
jax_arrays = collections.defaultdict(array_info_lambda)
153-
non_jax_arrays = array_info_lambda()
153+
non_reshardable_arrays = array_info_lambda()
154154
for index, (arr, dst_sharding) in enumerate(zip(flat_x, flat_sharding)):
155155
if not isinstance(dst_sharding, jax.sharding.Sharding):
156156
raise ValueError("`sharding` must contain only `jax.sharding.Sharding`")
157-
if isinstance(arr, jax.Array):
157+
if not isinstance(arr, jax.Array) or (
158+
hasattr(arr, "dtype")
159+
and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
160+
):
161+
non_reshardable_arrays["arrays"].append(arr)
162+
non_reshardable_arrays["indices"].append(index)
163+
non_reshardable_arrays["dst_shardings"].append(dst_sharding)
164+
else:
158165
device_set = frozenset(arr.sharding.device_set)
159166
jax_arrays[device_set]["arrays"].append(arr)
160167
jax_arrays[device_set]["indices"].append(index)
161168
jax_arrays[device_set]["dst_shardings"].append(dst_sharding)
162-
else:
163-
non_jax_arrays["arrays"].append(arr)
164-
non_jax_arrays["indices"].append(index)
165-
non_jax_arrays["dst_shardings"].append(dst_sharding)
166-
167-
if non_jax_arrays["arrays"]:
168-
non_jax_arrays["arrays"] = jax.device_put(
169-
non_jax_arrays["arrays"],
170-
non_jax_arrays["dst_shardings"],
169+
170+
if non_reshardable_arrays["arrays"]:
171+
non_reshardable_arrays["arrays"] = jax.device_put(
172+
non_reshardable_arrays["arrays"],
173+
non_reshardable_arrays["dst_shardings"],
171174
donate=donate,
172175
may_alias=may_alias,
173176
)
@@ -186,7 +189,9 @@ def reshard(
186189
).execute(tuple(array_info["arrays"]))
187190

188191
result = [None] * len(flat_x)
189-
for arr, idx in zip(non_jax_arrays["arrays"], non_jax_arrays["indices"]):
192+
for arr, idx in zip(
193+
non_reshardable_arrays["arrays"], non_reshardable_arrays["indices"]
194+
):
190195
result[idx] = arr
191196
for array_info in jax_arrays.values():
192197
for arr, idx in zip(array_info["arrays"], array_info["indices"]):

0 commit comments

Comments
 (0)