@@ -49,16 +49,19 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):
4949
5050 def __init__ (
5151 self ,
52- read_timeout : datetime .timedelta | None = None ,
52+ timeout : datetime .timedelta | None = None ,
5353 use_ocdbt : bool = False ,
5454 ):
55- """Constructor .
55+ """Orbax array handler for Pathways on Cloud with Persistence API .
5656
5757 Args:
58- read_timeout: Duration indicating the timeout for reading arrays
58+ timeout: Duration indicating the timeout for reading and writing arrays.
59+ Default is 1 hour.
5960 use_ocdbt: allows using Tensorstore OCDBT driver.
6061 """
61- self ._read_timeout = read_timeout
62+ if timeout is None :
63+ timeout = datetime .timedelta (hours = 1 )
64+ self .timeout = timeout
6265
6366 if use_ocdbt :
6467 raise ValueError ("OCDBT not supported for Pathways." )
@@ -92,7 +95,7 @@ async def serialize(
9295
9396 self ._wait_for_directory_creation_signals ()
9497 locations , names = extract_parent_dir_and_name (infos )
95- f = functools .partial (helper .write_one_array , timeout = self ._read_timeout )
98+ f = functools .partial (helper .write_one_array , timeout = self .timeout )
9699 futures_results = list (map (f , locations , names , values ))
97100
98101 return [
@@ -181,7 +184,7 @@ async def deserialize(
181184 grouped_global_shapes ,
182185 grouped_shardings ,
183186 global_mesh .devices ,
184- timeout = self ._read_timeout ,
187+ timeout = self .timeout ,
185188 )
186189 # each persistence call is awaited serially.
187190 read_future .result ()
@@ -191,7 +194,7 @@ async def deserialize(
191194
192195
193196def register_pathways_handlers (
194- read_timeout : datetime .timedelta | None = None ,
197+ timeout : datetime .timedelta | None = None ,
195198):
196199 """Function that must be called before saving or restoring with Pathways."""
197200 logger .debug (
@@ -200,7 +203,7 @@ def register_pathways_handlers(
200203 type_handlers .register_type_handler (
201204 jax .Array ,
202205 CloudPathwaysArrayHandler (
203- read_timeout = read_timeout ,
206+ timeout = timeout ,
204207 ),
205208 override = True ,
206209 )
0 commit comments