Skip to content

Commit 32a3b47

Browse files
lukebaumanncopybara-github
authored andcommitted
Support jax.random.PRNGKey serialization in Pathways Orbax handler.
This change allows `CloudPathwaysArrayHandler` to correctly save and restore `jax.random.PRNGKey` objects by extracting and wrapping the key data, and storing metadata about the key implementation using an `ArrayMetadataStore`. This change introduces a dependency on Orbax's internal API. PiperOrigin-RevId: 816401473
1 parent c9fb204 commit 32a3b47

File tree

2 files changed

+97
-12
lines changed

2 files changed

+97
-12
lines changed

pathwaysutils/_initialize.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919

2020
import jax
21+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
2122
from pathwaysutils import profiling
2223
from pathwaysutils import proxy_backend
2324
from pathwaysutils.persistence import orbax_handler
@@ -94,6 +95,7 @@ def initialize() -> None:
9495
if _is_persistence_enabled():
9596
orbax_handler.register_pathways_handlers(
9697
timeout=datetime.timedelta(hours=1),
98+
array_metadata_store=array_metadata_store_lib.Store(),
9799
)
98100

99101
# Turn off JAX compilation cache because Pathways handles its own

pathwaysutils/persistence/orbax_handler.py

Lines changed: 95 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@
1414
"""TypeHandlers supporting Pathways backend."""
1515

1616
import collections
17-
from collections.abc import Sequence
17+
from collections.abc import Coroutine, Sequence
1818
import concurrent.futures
1919
import datetime
2020
import functools
2121
import logging
22-
import typing
22+
from typing import Any, cast
2323

2424
import jax
2525
from orbax.checkpoint import future
2626
from orbax.checkpoint import type_handlers
27+
from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib
28+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
2729
from pathwaysutils.persistence import helper
2830

2931

@@ -33,6 +35,7 @@
3335
SaveArgs = type_handlers.SaveArgs
3436
RestoreArgs = type_handlers.RestoreArgs
3537
ArrayRestoreArgs = type_handlers.ArrayRestoreArgs
38+
ArrayMetadata = array_metadata_lib.ArrayMetadata
3639

3740

3841
def extract_parent_dir_and_name(
@@ -51,26 +54,33 @@ def __init__(
5154
self,
5255
timeout: datetime.timedelta | None = None,
5356
use_ocdbt: bool = False,
57+
array_metadata_store: array_metadata_store_lib.Store | None = None,
5458
):
5559
"""Orbax array handler for Pathways on Cloud with Persistence API.
5660
5761
Args:
5862
timeout: Duration indicating the timeout for reading and writing arrays.
5963
Default is 1 hour.
6064
use_ocdbt: allows using Tensorstore OCDBT driver.
65+
array_metadata_store: An optional store for writing and reading array
66+
metadata. Only required for saving new-style jax random keys.
6167
"""
6268
if timeout is None:
6369
timeout = datetime.timedelta(hours=1)
6470
self.timeout = timeout
6571

6672
if use_ocdbt:
6773
raise ValueError("OCDBT not supported for Pathways.")
68-
super().__init__()
74+
super().__init__(array_metadata_store=array_metadata_store)
6975

7076
async def _background_serialize(
7177
self,
7278
futures_results: Sequence[concurrent.futures.Future[None]],
79+
metadata_coroutine: Coroutine[Any, Any, None] | None = None,
7380
) -> None:
81+
if metadata_coroutine:
82+
await metadata_coroutine
83+
7484
for future_result in futures_results:
7585
future_result.result()
7686

@@ -86,21 +96,60 @@ async def serialize(
8696
values: Sequence[jax.Array],
8797
infos: Sequence[ParamInfo],
8898
args: Sequence[SaveArgs] | None = None,
89-
) -> Sequence[future.Future]:
99+
) -> list[future.Future]:
90100
"""Uses Pathways Persistence API to serialize a jax array."""
91101
type_handlers.check_input_arguments(values, infos, args)
92102

93103
if any([arg.dtype is not None for arg in args]):
94104
raise ValueError("Casting during save not supported for Pathways.")
95105

106+
array_metadatas = []
107+
any_random_key = False
108+
arrays = []
109+
for v, info, arg in zip(values, infos, args):
110+
ext_metadata = None
111+
if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
112+
any_random_key = True
113+
impl = str(jax.random.key_impl(v))
114+
v = jax.random.key_data(v)
115+
ext_metadata = {array_metadata_lib.RANDOM_KEY_IMPL: impl}
116+
117+
array_metadatas.append(
118+
ArrayMetadata(
119+
param_name=info.name,
120+
shape=v.shape,
121+
dtype=(arg.dtype if arg is not None else v.dtype),
122+
write_shape=getattr(v, "local_shape", v.shape),
123+
chunk_shape=getattr(v, "local_shape", v.shape),
124+
use_ocdbt=False,
125+
use_zarr3=False,
126+
ext_metadata=ext_metadata,
127+
)
128+
)
129+
arrays.append(v)
130+
131+
metadata_coroutine = None
132+
if any_random_key:
133+
if self._array_metadata_store is None:
134+
raise ValueError(
135+
"Array metadata store is not set with a checkpoint that requires"
136+
f" it. Array metadata: {array_metadatas}"
137+
)
138+
139+
metadata_coroutine = self._array_metadata_store.write(
140+
checkpoint_dir=infos[0].parent_dir,
141+
array_metadatas=array_metadatas,
142+
process_index=0,
143+
)
144+
96145
self._wait_for_directory_creation_signals()
97146
locations, names = extract_parent_dir_and_name(infos)
98147
f = functools.partial(helper.write_one_array, timeout=self.timeout)
99-
futures_results = list(map(f, locations, names, values))
148+
futures_results = list(map(f, locations, names, arrays))
100149

101150
return [
102151
future.CommitFutureAwaitingContractedSignals(
103-
self._background_serialize(futures_results),
152+
self._background_serialize(futures_results, metadata_coroutine),
104153
name="cloud_pathways_array_handler",
105154
)
106155
]
@@ -109,7 +158,7 @@ async def deserialize(
109158
self,
110159
infos: Sequence[ParamInfo],
111160
args: Sequence[RestoreArgs] | None = None,
112-
) -> Sequence[jax.Array]:
161+
) -> list[jax.Array]:
113162
"""Uses Pathways Persistence API to deserialize a jax array."""
114163
if args is None:
115164
raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")
@@ -128,7 +177,7 @@ async def deserialize(
128177
"To restore jax.Array, provide ArrayRestoreArgs; found"
129178
f" {type(arg).__name__}"
130179
)
131-
arg = typing.cast(ArrayRestoreArgs, arg)
180+
arg = cast(ArrayRestoreArgs, arg)
132181
if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None):
133182
raise ValueError(
134183
"Sharding of jax.Array cannot be None. Provide `mesh`"
@@ -143,7 +192,7 @@ async def deserialize(
143192
else:
144193
if not isinstance(arg.sharding, jax.sharding.NamedSharding):
145194
raise ValueError("Pathways only supports jax.sharding.NamedSharding.")
146-
sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding)
195+
sharding = cast(jax.sharding.NamedSharding, arg.sharding)
147196
global_meshes.append(sharding.mesh)
148197
mesh_axes.append(sharding.spec)
149198
shardings.append(sharding)
@@ -163,13 +212,30 @@ async def deserialize(
163212
]
164213
dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)]
165214

215+
array_metadatas_cache = {}
216+
if self._array_metadata_store is not None:
217+
if array_metadatas := await self._array_metadata_store.read(
218+
checkpoint_dir=infos[0].parent_dir,
219+
process_index=0,
220+
):
221+
if not isinstance(array_metadatas, list):
222+
raise ValueError(
223+
"Array metadata store returned unexpected result:"
224+
f" {array_metadatas}"
225+
)
226+
227+
array_metadatas_cache = {
228+
array_metadata.param_name: array_metadata
229+
for array_metadata in array_metadatas
230+
}
231+
166232
# Group inputs by global_mesh so that we can perform batched Array
167233
# construction for each global_mesh.
168234
inputs_by_global_mesh = collections.defaultdict(list)
169235
for i, global_mesh in enumerate(global_meshes):
170236
inputs_by_global_mesh[global_mesh].append(i)
171237

172-
results = [None] * len(infos)
238+
results = cast(list[jax.Array], [None] * len(infos))
173239

174240
for global_mesh, idxs in inputs_by_global_mesh.items():
175241
grouped_infos = [infos[idx] for idx in idxs]
@@ -188,13 +254,29 @@ async def deserialize(
188254
)
189255
# each persistence call is awaited serially.
190256
read_future.result()
191-
for idx, arr in zip(idxs, grouped_arrays):
257+
for idx, info, arr in zip(idxs, grouped_infos, grouped_arrays):
258+
if meta := array_metadatas_cache.get(info.name):
259+
assert isinstance(
260+
meta, array_metadata_lib.SerializedArrayMetadata
261+
), f"Expecting SerializedArrayMetadata but got {type(meta)}."
262+
if meta.ext_metadata:
263+
assert isinstance(meta.ext_metadata, dict), (
264+
"Expecting ext_metadata to be a dict but got"
265+
f" {type(meta.ext_metadata)}."
266+
)
267+
268+
if impl := meta.ext_metadata.get(
269+
array_metadata_lib.RANDOM_KEY_IMPL
270+
):
271+
arr = jax.random.wrap_key_data(arr, impl=impl)
192272
results[idx] = arr
193-
return results # pytype: disable=bad-return-type
273+
274+
return results
194275

195276

196277
def register_pathways_handlers(
197278
timeout: datetime.timedelta | None = None,
279+
array_metadata_store: array_metadata_store_lib.Store | None = None,
198280
):
199281
"""Function that must be called before saving or restoring with Pathways."""
200282
logger.debug(
@@ -204,6 +286,7 @@ def register_pathways_handlers(
204286
jax.Array,
205287
CloudPathwaysArrayHandler(
206288
timeout=timeout,
289+
array_metadata_store=array_metadata_store,
207290
),
208291
override=True,
209292
)

0 commit comments

Comments
 (0)