From 90fda0f39b00e93aef5f03a96de07330bbfebee3 Mon Sep 17 00:00:00 2001
From: Haolin Chen
Date: Thu, 29 May 2025 16:49:54 -0700
Subject: [PATCH 1/2] update

---
 .../torch_xla_models/configs/default.yaml |  5 ++
 torchprime/torch_xla_models/train.py      | 76 ++++++++++++++++++-
 2 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/torchprime/torch_xla_models/configs/default.yaml b/torchprime/torch_xla_models/configs/default.yaml
index 3400aecf..ed2ca7c5 100644
--- a/torchprime/torch_xla_models/configs/default.yaml
+++ b/torchprime/torch_xla_models/configs/default.yaml
@@ -24,6 +24,11 @@ profile_duration: 100000
 # This might be overwritten when using tp run to launch the run using XPK
 output_dir: outputs
 
+# Checkpoint directory
+checkpoint_dir: checkpoints/
+checkpoint_step: null
+save_steps: 15
+
 optimizer:
   learning_rate: 5.e-5
 lr_scheduler:
diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py
index 38299fc9..8cc0be4e 100644
--- a/torchprime/torch_xla_models/train.py
+++ b/torchprime/torch_xla_models/train.py
@@ -11,6 +11,7 @@
 import datasets
 import hydra
 import torch
+import torch.distributed as dist
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.profiler as xp
@@ -18,6 +19,7 @@
 import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr
 import transformers
+import wandb
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from torch.utils.data import DataLoader, Dataset, IterableDataset
@@ -30,11 +32,12 @@
   get_scheduler,
   set_seed,
 )
+from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer
 from transformers.optimization import Adafactor
 from transformers.trainer_pt_utils import get_module_class_from_name
 from transformers.utils import check_min_version
 
-from torchprime.data.dataset import make_huggingface_dataset
+from torchprime.data.dataset import make_huggingface_dataset, make_gcs_dataset
 from torchprime.layers.sequential import HomogeneousSequential
 from torchprime.metrics.metrics import MetricsLogger
 from torchprime.metrics.mfu import compute_mfu
@@ -57,6 +60,7 @@
 xr.use_spmd()
 assert xr.is_spmd() is True
 
+dist.init_process_group(backend='gloo', init_method='xla://')
 
 class Trainer:
   """The trainer."""
@@ -126,6 +130,11 @@ def __init__(
       num_training_steps=self.config.max_steps,
     )
 
+    # Initialize checkpoint manager
+    self.ckpt_dir = config.checkpoint_dir
+    self.ckpt_mgr = CheckpointManager(path=self.ckpt_dir, save_interval=config.save_steps)
+    self.start_step = 0
+
     # Execute all initialization work queued so far before starting training.
     torch_xla.sync()
 
@@ -137,6 +146,34 @@ def _prime_optimizer(self):
     self.optimizer.step()
     torch_xla.sync()
 
+  def _load_checkpoint(self):
+    """Load optimizer, scheduler, and training state from checkpoint."""
+    tracked_steps = self.ckpt_mgr.all_steps()
+    if not tracked_steps:
+      logger.warning("No checkpoint steps found. Starting from scratch.")
+      return
+    self.optimizer = prime_optimizer(self.optimizer)  # NOTE: needed to create the dummy state dict for the optimizer
+    state_dict = {
+      "model": self.model.state_dict(),
+      "optimizer": self.optimizer.state_dict(),
+      "scheduler": self.lr_scheduler.state_dict(),
+      "step": self.start_step,
+    }
+    if self.config.checkpoint_step in tracked_steps:
+      logger.info(f"Loading checkpoint from step {self.config.checkpoint_step}")
+      self.ckpt_mgr.restore(self.config.checkpoint_step, state_dict)
+    elif self.config.checkpoint_step == "latest":
+      last_step = max(tracked_steps)
+      logger.warning(f"Checkpoint step {self.config.checkpoint_step} not found in tracked steps {tracked_steps}. Loading from latest checkpoint {last_step}.")
+      self.ckpt_mgr.restore(last_step, state_dict)
+    else:
+      raise ValueError(f"Invalid checkpoint step: {self.config.checkpoint_step}. Must be one of {tracked_steps} or 'latest'.")
+
+    self.model.load_state_dict(state_dict["model"])
+    self.optimizer.load_state_dict(state_dict["optimizer"])
+    self.lr_scheduler.load_state_dict(state_dict["scheduler"])
+    self.start_step = state_dict["step"]
+
   def _get_train_dataloader(self):
     if self.train_dataset is None:
       raise ValueError("Trainer: training requires a train_dataset.")
@@ -264,6 +301,8 @@ def _get_classes_by_names(self, model, activation_checkpoint_layers: list[str]):
     return tuple(classes_to_checkpoint)
 
   def train_loop(self):
+    if self.config.checkpoint_step is not None:
+      self._load_checkpoint()
     self.model.train()
     self.model.zero_grad()
 
@@ -276,9 +315,24 @@ def train_loop(self):
     logger.info("Starting training")
     logger.info(f"  Max step: {max_step}")
     logger.info(f"  Global batch size: {self.global_batch_size}")
-
+    if hasattr(self, 'start_step') and self.start_step > 0:
+      logger.info(f"  Resuming from step: {self.start_step}")
+    # Initialize epoch and step counters, accounting for checkpoint loading
     epoch = 0
-    for step in range(max_step):
+    start_step = self.start_step
+
+    # Skip batches if we're resuming from a checkpoint
+    if start_step > 0:
+      logger.info(f"Skipping {start_step} batches to resume from checkpoint...")
+      for _ in range(start_step):
+        try:
+          next(train_iterator)
+        except StopIteration:
+          epoch += 1
+          train_iterator = iter(train_loader)
+          next(train_iterator)
+
+    for step in range(start_step, max_step):
       try:
         batch = next(train_iterator)
       except StopIteration:
@@ -308,6 +362,22 @@ def step_closure(epoch, step, loss, trace_start_time, trace_end_time):
         run_async=True,
       )
 
+      if step > self.start_step and step % self.config.save_steps == 0:
+        # NOTE: currently we save the checkpoint synchronously
+        xm.wait_device_ops()  # Wait for all XLA operations to complete
+        state_dict = {
+          "model": self.model.state_dict(),
+          "optimizer": self.optimizer.state_dict(),
+          "scheduler": self.lr_scheduler.state_dict(),
+          "step": step,
+        }
+        try:
+          self.ckpt_mgr.save(step, state_dict, force=True)
+          logger.info(f"Checkpoint saved at step {step} to {self.ckpt_dir}")
+        except Exception as e:
+          logger.error(f"Failed to save checkpoint at step with ckpt_mgr {step}: {e}")
+        xm.wait_device_ops()  # Ensure save is complete before logging
+
       # Capture profile at the prefer step
       if step == self.config.profile_step:
         # Wait until device execution catches up to tracing before triggering the profile. This will

From 8b94a028df6226f150a9830610c3b73d694d96ae Mon Sep 17 00:00:00 2001
From: Haolin Chen
Date: Mon, 2 Jun 2025 15:31:38 -0700
Subject: [PATCH 2/2] clean up

---
 torchprime/torch_xla_models/configs/default.yaml |  2 +-
 torchprime/torch_xla_models/train.py             | 15 +++++++--------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/torchprime/torch_xla_models/configs/default.yaml b/torchprime/torch_xla_models/configs/default.yaml
index ed2ca7c5..7aca36d3 100644
--- a/torchprime/torch_xla_models/configs/default.yaml
+++ b/torchprime/torch_xla_models/configs/default.yaml
@@ -26,7 +26,7 @@ output_dir: outputs
 
 # Checkpoint directory
 checkpoint_dir: checkpoints/
-checkpoint_step: null
+resume_from_checkpoint: null
 save_steps: 15
 
 optimizer:
diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py
index 8cc0be4e..82edd29d 100644
--- a/torchprime/torch_xla_models/train.py
+++ b/torchprime/torch_xla_models/train.py
@@ -19,7 +19,6 @@
 import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr
 import transformers
-import wandb
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from torch.utils.data import DataLoader, Dataset, IterableDataset
@@ -159,15 +158,15 @@ def _load_checkpoint(self):
       "scheduler": self.lr_scheduler.state_dict(),
       "step": self.start_step,
     }
-    if self.config.checkpoint_step in tracked_steps:
-      logger.info(f"Loading checkpoint from step {self.config.checkpoint_step}")
-      self.ckpt_mgr.restore(self.config.checkpoint_step, state_dict)
-    elif self.config.checkpoint_step == "latest":
+    if self.config.resume_from_checkpoint in tracked_steps:
+      logger.info(f"Loading checkpoint from step {self.config.resume_from_checkpoint}")
+      self.ckpt_mgr.restore(self.config.resume_from_checkpoint, state_dict)
+    elif self.config.resume_from_checkpoint == "latest":
       last_step = max(tracked_steps)
-      logger.warning(f"Checkpoint step {self.config.checkpoint_step} not found in tracked steps {tracked_steps}. Loading from latest checkpoint {last_step}.")
+      logger.warning(f"Checkpoint step {self.config.resume_from_checkpoint} not found in tracked steps {tracked_steps}. Loading from latest checkpoint {last_step}.")
       self.ckpt_mgr.restore(last_step, state_dict)
     else:
-      raise ValueError(f"Invalid checkpoint step: {self.config.checkpoint_step}. Must be one of {tracked_steps} or 'latest'.")
+      raise ValueError(f"Invalid checkpoint step: {self.config.resume_from_checkpoint}. Must be one of {tracked_steps} or 'latest'.")
 
     self.model.load_state_dict(state_dict["model"])
     self.optimizer.load_state_dict(state_dict["optimizer"])
     self.lr_scheduler.load_state_dict(state_dict["scheduler"])
@@ -301,7 +300,7 @@ def _get_classes_by_names(self, model, activation_checkpoint_layers: list[str]):
     return tuple(classes_to_checkpoint)
 
   def train_loop(self):
-    if self.config.checkpoint_step is not None:
+    if self.config.resume_from_checkpoint is not None:
       self._load_checkpoint()
     self.model.train()
     self.model.zero_grad()
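
Usage sketch (not part of the patches above): the snippet below walks through the same CheckpointManager save/restore cycle that these patches wire into the Trainer, assuming an XLA/SPMD environment initialized the way the first patch does it. The toy Linear model, the AdamW optimizer, the local "checkpoints/" path, and the step number 100 are illustrative placeholders, not values taken from the patches.

# Minimal sketch of the save/restore flow, assuming a single-host XLA/SPMD setup.
# The model, optimizer, path, and step number are placeholders.
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer

xr.use_spmd()
dist.init_process_group(backend="gloo", init_method="xla://")  # mirrors the patch's setup

model = torch.nn.Linear(8, 8).to(xm.xla_device())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
ckpt_mgr = CheckpointManager(path="checkpoints/", save_interval=15)

# Restore: prime the optimizer so its state dict has its final structure,
# then restore in place from the newest tracked step, if any exist.
tracked_steps = ckpt_mgr.all_steps()
if tracked_steps:
  optimizer = prime_optimizer(optimizer)
  state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 0}
  ckpt_mgr.restore(max(tracked_steps), state_dict)
  model.load_state_dict(state_dict["model"])
  optimizer.load_state_dict(state_dict["optimizer"])

# Save: let pending device work finish, then write synchronously
# (force=True bypasses the save_interval check, as in the patches).
xm.wait_device_ops()
ckpt_mgr.save(100, {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 100}, force=True)

With the second patch applied, resumption is driven entirely by config: set resume_from_checkpoint to a tracked step number or "latest" (for example as a Hydra override on the training command), and the Trainer restores model, optimizer, and scheduler state, then skips the already-consumed batches before continuing from that step.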