1414"""TypeHandlers supporting Pathways backend."""
1515
1616import collections
17- from collections .abc import Sequence
17+ from collections .abc import Coroutine , Sequence
1818import concurrent .futures
1919import datetime
2020import functools
2121import logging
22- import typing
22+ from typing import Any , cast
2323
2424import jax
2525from orbax .checkpoint import future
2626from orbax .checkpoint import type_handlers
27+ from orbax .checkpoint ._src .metadata import array_metadata as array_metadata_lib
28+ from orbax .checkpoint ._src .metadata import array_metadata_store as array_metadata_store_lib
2729from pathwaysutils .persistence import helper
2830
2931
3335SaveArgs = type_handlers .SaveArgs
3436RestoreArgs = type_handlers .RestoreArgs
3537ArrayRestoreArgs = type_handlers .ArrayRestoreArgs
38+ ArrayMetadata = array_metadata_lib .ArrayMetadata
3639
3740
3841def extract_parent_dir_and_name (
def __init__(
    self,
    timeout: datetime.timedelta | None = None,
    use_ocdbt: bool = False,
    array_metadata_store: array_metadata_store_lib.Store | None = None,
):
  """Builds the Pathways-on-Cloud array handler backed by the Persistence API.

  Args:
    timeout: Upper bound on how long a single array read or write may take.
      When omitted, one hour is used.
    use_ocdbt: Must be False; the Tensorstore OCDBT driver is rejected by
      this handler.
    array_metadata_store: Optional store forwarded to the base handler. This
      handler writes array metadata to it when saving new-style jax random
      keys and, when set, reads it back on restore to re-wrap those keys.

  Raises:
    ValueError: If `use_ocdbt` is True.
  """
  # Resolve the default first so `self.timeout` is always a timedelta.
  self.timeout = datetime.timedelta(hours=1) if timeout is None else timeout

  if use_ocdbt:
    raise ValueError("OCDBT not supported for Pathways.")

  super().__init__(array_metadata_store=array_metadata_store)
6975
7076 async def _background_serialize (
7177 self ,
7278 futures_results : Sequence [concurrent .futures .Future [None ]],
79+ metadata_coroutine : Coroutine [Any , Any , None ] | None = None ,
7380 ) -> None :
81+ if metadata_coroutine :
82+ await metadata_coroutine
83+
7484 for future_result in futures_results :
7585 future_result .result ()
7686
async def serialize(
    self,
    values: Sequence[jax.Array],
    infos: Sequence[ParamInfo],
    args: Sequence[SaveArgs] | None = None,
) -> list[future.Future]:
  """Uses Pathways Persistence API to serialize a jax array.

  Typed PRNG key arrays are unwrapped to their raw key data before being
  written, and the key implementation name is recorded in the array's
  `ext_metadata` under `array_metadata_lib.RANDOM_KEY_IMPL`. Array metadata is
  only written to the metadata store when at least one value is such a key.

  Args:
    values: Arrays to write.
    infos: One `ParamInfo` per value; supplies the name and parent directory.
    args: One `SaveArgs` per value. When None, default `SaveArgs` are used.

  Returns:
    A single-element list holding the commit future that completes once the
    optional metadata write and all array writes have finished.

  Raises:
    ValueError: If any `SaveArgs` requests a dtype cast, or if a random key is
      being saved while no array metadata store is configured.
  """
  # `args` is optional in the signature, so give it a usable default instead
  # of failing on `len(None)` / iteration below.
  if args is None:
    args = [SaveArgs()] * len(values)
  type_handlers.check_input_arguments(values, infos, args)

  if any(arg.dtype is not None for arg in args):
    raise ValueError("Casting during save not supported for Pathways.")

  array_metadatas = []
  any_random_key = False
  arrays = []
  for v, info in zip(values, infos):
    ext_metadata = None
    if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
      any_random_key = True
      impl = str(jax.random.key_impl(v))
      # Persist the raw key data; the impl name is kept in the metadata.
      v = jax.random.key_data(v)
      ext_metadata = {array_metadata_lib.RANDOM_KEY_IMPL: impl}

    array_metadatas.append(
        ArrayMetadata(
            param_name=info.name,
            shape=v.shape,
            # Casting is rejected above, so every `arg.dtype` is None here;
            # record the dtype of the array that is actually written.
            dtype=v.dtype,
            write_shape=getattr(v, "local_shape", v.shape),
            chunk_shape=getattr(v, "local_shape", v.shape),
            use_ocdbt=False,
            use_zarr3=False,
            ext_metadata=ext_metadata,
        )
    )
    arrays.append(v)

  metadata_coroutine = None
  if any_random_key:
    if self._array_metadata_store is None:
      raise ValueError(
          "Array metadata store is not set with a checkpoint that requires"
          f" it. Array metadata: {array_metadatas}"
      )

    # Not awaited here; it is handed to `_background_serialize` below.
    metadata_coroutine = self._array_metadata_store.write(
        checkpoint_dir=infos[0].parent_dir,
        array_metadatas=array_metadatas,
        process_index=0,
    )

  self._wait_for_directory_creation_signals()
  locations, names = extract_parent_dir_and_name(infos)
  f = functools.partial(helper.write_one_array, timeout=self.timeout)
  # Write the unwrapped arrays, not the original `values`.
  futures_results = list(map(f, locations, names, arrays))

  return [
      future.CommitFutureAwaitingContractedSignals(
          self._background_serialize(futures_results, metadata_coroutine),
          name="cloud_pathways_array_handler",
      )
  ]
@@ -109,7 +158,7 @@ async def deserialize(
109158 self ,
110159 infos : Sequence [ParamInfo ],
111160 args : Sequence [RestoreArgs ] | None = None ,
112- ) -> Sequence [jax .Array ]:
161+ ) -> list [jax .Array ]:
113162 """Uses Pathways Persistence API to deserialize a jax array."""
114163 if args is None :
115164 raise ValueError ("Must provide ArrayRestoreArgs to restore as jax.Array." )
@@ -128,7 +177,7 @@ async def deserialize(
128177 "To restore jax.Array, provide ArrayRestoreArgs; found"
129178 f" { type (arg ).__name__ } "
130179 )
131- arg = typing . cast (ArrayRestoreArgs , arg )
180+ arg = cast (ArrayRestoreArgs , arg )
132181 if arg .sharding is None and (arg .mesh is None or arg .mesh_axes is None ):
133182 raise ValueError (
134183 "Sharding of jax.Array cannot be None. Provide `mesh`"
@@ -143,7 +192,7 @@ async def deserialize(
143192 else :
144193 if not isinstance (arg .sharding , jax .sharding .NamedSharding ):
145194 raise ValueError ("Pathways only supports jax.sharding.NamedSharding." )
146- sharding = typing . cast (jax .sharding .NamedSharding , arg .sharding )
195+ sharding = cast (jax .sharding .NamedSharding , arg .sharding )
147196 global_meshes .append (sharding .mesh )
148197 mesh_axes .append (sharding .spec )
149198 shardings .append (sharding )
@@ -163,13 +212,30 @@ async def deserialize(
163212 ]
164213 dtypes = [m .dtype if d is None else d for m , d in zip (metadatas , dtypes )]
165214
215+ array_metadatas_cache = {}
216+ if self ._array_metadata_store is not None :
217+ if array_metadatas := await self ._array_metadata_store .read (
218+ checkpoint_dir = infos [0 ].parent_dir ,
219+ process_index = 0 ,
220+ ):
221+ if not isinstance (array_metadatas , list ):
222+ raise ValueError (
223+ "Array metadata store returned unexpected result:"
224+ f" { array_metadatas } "
225+ )
226+
227+ array_metadatas_cache = {
228+ array_metadata .param_name : array_metadata
229+ for array_metadata in array_metadatas
230+ }
231+
166232 # Group inputs by global_mesh so that we can perform batched Array
167233 # construction for each global_mesh.
168234 inputs_by_global_mesh = collections .defaultdict (list )
169235 for i , global_mesh in enumerate (global_meshes ):
170236 inputs_by_global_mesh [global_mesh ].append (i )
171237
172- results = [ None ] * len (infos )
238+ results = cast ( list [ jax . Array ], [ None ] * len (infos ) )
173239
174240 for global_mesh , idxs in inputs_by_global_mesh .items ():
175241 grouped_infos = [infos [idx ] for idx in idxs ]
@@ -188,13 +254,29 @@ async def deserialize(
188254 )
189255 # each persistence call is awaited serially.
190256 read_future .result ()
191- for idx , arr in zip (idxs , grouped_arrays ):
257+ for idx , info , arr in zip (idxs , grouped_infos , grouped_arrays ):
258+ if meta := array_metadatas_cache .get (info .name ):
259+ assert isinstance (
260+ meta , array_metadata_lib .SerializedArrayMetadata
261+ ), f"Expecting SerializedArrayMetadata but got { type (meta )} ."
262+ if meta .ext_metadata :
263+ assert isinstance (meta .ext_metadata , dict ), (
264+ "Expecting ext_metadata to be a dict but got"
265+ f" { type (meta .ext_metadata )} ."
266+ )
267+
268+ if impl := meta .ext_metadata .get (
269+ array_metadata_lib .RANDOM_KEY_IMPL
270+ ):
271+ arr = jax .random .wrap_key_data (arr , impl = impl )
192272 results [idx ] = arr
193- return results # pytype: disable=bad-return-type
273+
274+ return results
194275
195276
196277def register_pathways_handlers (
197278 timeout : datetime .timedelta | None = None ,
279+ array_metadata_store : array_metadata_store_lib .Store | None = None ,
198280):
199281 """Function that must be called before saving or restoring with Pathways."""
200282 logger .debug (
@@ -204,6 +286,7 @@ def register_pathways_handlers(
204286 jax .Array ,
205287 CloudPathwaysArrayHandler (
206288 timeout = timeout ,
289+ array_metadata_store = array_metadata_store ,
207290 ),
208291 override = True ,
209292 )
0 commit comments