Skip to content

Commit

Permalink
Merge pull request #871 from google:blured_abhinav
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 672723485
  • Loading branch information
maxtext authors committed Sep 10, 2024
2 parents 2706dc4 + b851912 commit 66684a2
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
import sys
from etils import epath
import functools
import time

from typing import Sequence
from typing import Sequence, Optional
from absl import app
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
Expand Down Expand Up @@ -166,8 +167,30 @@ def clear_buffered_metrics():
_buffered_step = None
_buffered_metrics = None

def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None):
"""Wrapper for saving checkpoint"""
def save_checkpoint(
checkpoint_manager,
step,
state,
dataset_type="c4",
data_iterator=None,
config: Optional[pyconfig.config] = None,
) -> bool:
"""Wrapper for saving checkpoint."""
if config and config.enable_checkpointing:
if (step % config.checkpoint_period == 0) or (
config.enable_emergency_checkpoint
and step % config.local_checkpoint_period == 0
):
blocking_until_ready_start = time.time()
max_logging.log(f"Waiting for step {step} to finish before checkpoint...")
# We block here on the step finishing so that our checkpointing metrics
# measure only checkpointing time, not training time.
jax.block_until_ready(state)
max_logging.log(
f"Waited {time.time() - blocking_until_ready_start} seconds for step "
f"{step} to finish before starting checkpointing."
)

# specify chunk_byte_size to force orbax to control maximum file size in checkpoint
save_args = jax.tree.map(
lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=_CHUNK_BYTE_SIZE), state
Expand Down Expand Up @@ -617,7 +640,7 @@ def train_loop(config, state=None):
last_step_completion = new_time

if checkpoint_manager is not None:
if save_checkpoint(checkpoint_manager, int(step), state, config.dataset_type, data_iterator):
if save_checkpoint(checkpoint_manager, int(step), state, config.dataset_type, data_iterator, config):
max_logging.log(f"saved a checkpoint at step {step}")

# Upon preemption, exit when and only when all ongoing saves are complete.
Expand Down

0 comments on commit 66684a2

Please sign in to comment.