Distributed Checkpointing #275

Open · wants to merge 3 commits into main

5 changes: 5 additions & 0 deletions torchprime/torch_xla_models/configs/default.yaml
@@ -24,6 +24,11 @@ profile_duration: 100000
# This might be overwritten when using tp run to launch the run using XPK
output_dir: outputs

# Checkpointing: where checkpoints are written, which checkpoint to resume from
# (a saved step number, "latest", or null to start fresh), and how often (in steps) to save.
checkpoint_dir: checkpoints/
resume_from_checkpoint: null
save_steps: 15

torch_dtype: bfloat16
optimizer:
type: adafactor
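
For reference, a resume run would override these keys rather than use the defaults above. A hypothetical set of values (path and numbers are illustrative only), following the semantics implemented in train.py below:

# Illustrative overrides, not defaults
checkpoint_dir: /mnt/shared/llama-run/checkpoints/   # hypothetical shared path
resume_from_checkpoint: latest   # or a concrete saved step, e.g. 15
save_steps: 100                  # checkpoint every 100 training steps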
75 changes: 72 additions & 3 deletions torchprime/torch_xla_models/train.py
@@ -11,6 +11,7 @@
import datasets
import hydra
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
@@ -30,11 +31,12 @@
get_scheduler,
set_seed,
)
from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer
from transformers.optimization import Adafactor
from transformers.trainer_pt_utils import get_module_class_from_name
from transformers.utils import check_min_version

from torchprime.data.dataset import make_huggingface_dataset, make_gcs_dataset
from torchprime.layers.sequential import HomogeneousSequential
from torchprime.metrics.metrics import MetricsLogger
from torchprime.metrics.mfu import compute_mfu
@@ -57,6 +59,7 @@
xr.use_spmd()
assert xr.is_spmd() is True

# Distributed checkpointing needs a torch.distributed process group; use the
# gloo backend with the xla:// init method for host-side coordination.
dist.init_process_group(backend='gloo', init_method='xla://')

class Trainer:
"""The trainer."""
@@ -130,6 +133,11 @@ def __init__(
num_training_steps=self.config.max_steps,
)

# Initialize checkpoint manager
self.ckpt_dir = config.checkpoint_dir
self.ckpt_mgr = CheckpointManager(path=self.ckpt_dir, save_interval=config.save_steps)
self.start_step = 0

# Execute all initialization work queued so far before starting training.
torch_xla.sync()

@@ -141,6 +149,34 @@ def _prime_optimizer(self):
self.optimizer.step()
torch_xla.sync()

def _load_checkpoint(self):
"""Load optimizer, scheduler, and training state from checkpoint."""
tracked_steps = self.ckpt_mgr.all_steps()
if not tracked_steps:
logger.warning("No checkpoint steps found. Starting from scratch.")
return
self.optimizer = prime_optimizer(self.optimizer) # NOTE: runs a dummy step so the optimizer state exists before restoring into it
state_dict = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.lr_scheduler.state_dict(),
"step": self.start_step,
}
if self.config.resume_from_checkpoint in tracked_steps:
logger.info(f"Loading checkpoint from step {self.config.resume_from_checkpoint}")
self.ckpt_mgr.restore(self.config.resume_from_checkpoint, state_dict)
elif self.config.resume_from_checkpoint == "latest":
last_step = max(tracked_steps)
logger.info(f"Resuming from latest checkpoint at step {last_step} (tracked steps: {tracked_steps}).")
self.ckpt_mgr.restore(last_step, state_dict)
else:
raise ValueError(f"Invalid checkpoint step: {self.config.resume_from_checkpoint}. Must be one of {tracked_steps} or 'latest'.")

self.model.load_state_dict(state_dict["model"])
self.optimizer.load_state_dict(state_dict["optimizer"])
self.lr_scheduler.load_state_dict(state_dict["scheduler"])
self.start_step = state_dict["step"]

def _get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
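
Note for reviewers: CheckpointManager.restore fills the supplied state_dict in place rather than returning a new one, which is why _load_checkpoint first builds a template from the live model, optimizer, and scheduler, and why prime_optimizer runs beforehand so the optimizer has materialized state to restore into. A minimal sketch of that contract with throwaway objects (assumes SPMD and the gloo process group are already initialized, as in this file; names are illustrative):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)            # illustrative model
optimizer = torch.optim.AdamW(model.parameters())     # illustrative optimizer
ckpt_mgr = CheckpointManager(path="checkpoints/", save_interval=15)

optimizer = prime_optimizer(optimizer)                # materialize optimizer state
template = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
steps = ckpt_mgr.all_steps()
if steps:
    ckpt_mgr.restore(max(steps), template)            # loads into `template` in place
    model.load_state_dict(template["model"])
    optimizer.load_state_dict(template["optimizer"])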
@@ -268,6 +304,8 @@ def _get_classes_by_names(self, model, activation_checkpoint_layers: list[str]):
return tuple(classes_to_checkpoint)

def train_loop(self, metrics_logger):
if self.config.resume_from_checkpoint is not None:
self._load_checkpoint()
self.model.train()
self.model.zero_grad()

@@ -279,9 +317,24 @@ def train_loop(self, metrics_logger):
logger.info("Starting training")
logger.info(f" Max step: {max_step}")
logger.info(f" Global batch size: {self.global_batch_size}")

if self.start_step > 0:
logger.info(f" Resuming from step: {self.start_step}")
# Initialize epoch and step counters, accounting for checkpoint loading
epoch = 0
start_step = self.start_step

# Skip batches if we're resuming from a checkpoint
if start_step > 0:
logger.info(f"Skipping {start_step} batches to resume from checkpoint...")
for _ in range(start_step):
try:
next(train_iterator)
except StopIteration:
epoch += 1
train_iterator = iter(train_loader)
next(train_iterator)

for step in range(start_step, max_step):
try:
batch = next(train_iterator)
except StopIteration:
@@ -311,6 +364,22 @@ def step_closure(epoch, step, loss, trace_start_time, trace_end_time):
run_async=True,
)

if step > self.start_step and step % self.config.save_steps == 0:
# NOTE: currently we save the checkpoint synchronously
xm.wait_device_ops() # Wait for all XLA operations to complete
state_dict = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.lr_scheduler.state_dict(),
"step": step,
}
try:
self.ckpt_mgr.save(step, state_dict, force=True)
logger.info(f"Checkpoint saved at step {step} to {self.ckpt_dir}")
except Exception as e:
logger.error(f"Failed to save checkpoint at step {step}: {e}")
xm.wait_device_ops() # Make sure queued device work has finished before continuing training

# Capture a profile at the preferred step
if step == self.config.profile_step:
# Wait until device execution catches up to tracing before triggering the profile. This will
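
Beyond the diff itself, the save path added in train_loop can be exercised standalone. A minimal sketch of the same pattern (synchronous save every save_steps steps), assuming xr.use_spmd() and dist.init_process_group have already run as above; model, optimizer, and paths are illustrative only:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)        # illustrative model
optimizer = torch.optim.AdamW(model.parameters())
ckpt_mgr = CheckpointManager(path="checkpoints/", save_interval=15)
save_steps = 15

for step in range(1, 61):
    # ... forward / backward / optimizer.step() elided ...
    if step % save_steps == 0:
        xm.wait_device_ops()                      # drain pending device work first
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": step,
        }
        # Synchronous save, matching this PR; CheckpointManager.save_async is the
        # non-blocking alternative hinted at by the NOTE above, if/when adopted.
        ckpt_mgr.save(step, state_dict, force=True)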