diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index 40c09dc25..e24aebad0 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -18,6 +18,7 @@ from typing import Optional, Union from etils import epath + from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions from orbax.checkpoint.logging import abstract_logger, cloud_logger, standard_logger, composite_logger import jax @@ -134,21 +135,29 @@ def map_to_pspec(data): pspec = data.sharding.spec mesh = data.sharding.mesh if not enable_single_replica_ckpt_restoring: - return orbax.checkpoint.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec) - orbax.checkpoint.type_handlers.register_type_handler( - jax.Array, orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), override=True - ) - orbax.checkpoint.type_handlers.register_type_handler( - jax.Array, orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), override=True - ) - replica_axis_index = 0 # for maxtext data is the first dimension + return orbax.checkpoint.type_handlers.ArrayRestoreArgs( + mesh=mesh, mesh_axes=pspec) + replica_axis_index = 0 replica_devices = _replica_devices(mesh.devices, replica_axis_index) replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) - single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) + single_replica_sharding = jax.sharding.NamedSharding( + replica_mesh, pspec) + + array_handler = ( + orbax.checkpoint.type_handlers.SingleReplicaArrayHandler( + replica_axis_index=0, + broadcast_memory_limit_bytes=1024 * 1024 * 1000 # 1000 MB limit + ) + ) + orbax.checkpoint.type_handlers.register_type_handler( + jax.Array, + array_handler, + override=True + ) + return orbax.checkpoint.type_handlers.SingleReplicaArrayRestoreArgs( sharding=jax.sharding.NamedSharding(mesh, pspec), single_replica_sharding=single_replica_sharding, - replica_axis_index=replica_axis_index, global_shape=data.shape, dtype=data.dtype, )