
Distributed Checkpointing #275


Open

hlnchen wants to merge 3 commits into main

Conversation


@hlnchen hlnchen commented Jun 2, 2025

Three configs are added to control checkpointing:

  • checkpoint_dir: directory for checkpoints; use gs://bucket/path/to/checkpoint to save to a GCS bucket
  • resume_from_checkpoint:
    • if null, no checkpoint is loaded and the model starts from the Hugging Face pretrained weights
    • if a positive integer step, the checkpoint manager tries to find and load weights from the checkpoint under checkpoint_dir/resume_from_checkpoint/
    • if latest, or if the requested step is not found by the manager, the last checkpoint is loaded
  • save_steps: save frequency, in training steps

The checkpoint state dict contains:

  • model
  • optimizer
  • lr_scheduler
  • step

After loading a checkpoint, training skips the first step iterations by looping through the dataloader (see the sketch below).
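A minimal sketch of how the resume resolution and step skipping described above could look. This is not the PR's actual code: the helper names and the checkpoint_dir/<step>/ directory layout are assumptions based on the description, and os.listdir would need a GCS-aware replacement for gs:// paths.

```python
import os
from typing import Optional, Union

def resolve_resume_step(checkpoint_dir: str,
                        resume_from_checkpoint: Union[int, str, None]) -> Optional[int]:
    """Decide which checkpoint step to load, following the rules above.

    Returns None when training should start from the pretrained weights.
    """
    if resume_from_checkpoint is None:
        return None  # no checkpoint: load Hugging Face pretrained weights instead

    # Checkpoints are assumed to live in numbered subdirectories: checkpoint_dir/<step>/
    available = sorted(int(d) for d in os.listdir(checkpoint_dir) if d.isdigit())
    if not available:
        return None

    if isinstance(resume_from_checkpoint, int) and resume_from_checkpoint in available:
        return resume_from_checkpoint
    # "latest", or a requested step that was not found: fall back to the last checkpoint
    return available[-1]

def skip_consumed_batches(dataloader, resumed_step: int):
    """Advance the dataloader past batches consumed before the checkpoint was taken."""
    it = iter(dataloader)
    for _ in range(resumed_step):
        next(it)  # discard batches from steps that already ran
    return it
```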

What's not included:

  • async checkpointing
  • saving in Hugging Face format

Haolin Chen added 2 commits May 29, 2025 16:49

google-cla bot commented Jun 2, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@yaoshiang
Collaborator

Hi Haolin Chen @hlnchen, thanks so much for starting this PR! There are a few things I'd ask for before we can get started on this.

  • Please sign the contributor agreement. You should seek legal advice on this, particularly if these contributions are on behalf of your employer.
  • Please resolve merge conflicts. I know this is a moving target since main changes constantly.
  • Please ensure linting rules are followed. We use ruff almost OOTB, so this hopefully won't be a serious lift.
  • Please include unit testing as well as manual performance testing to demonstrate that this works as intended, particularly on larger models. I realize you used torch_xla's distributed checkpointing functionality; however, I have looked at its tests in the past and, to my knowledge, it was never exercised on a large model. I think there's a high risk of an inadvertent all-reduce in there, which we can really only rule out with performance checks and by patching into the underlying tensors (a rough sketch of such a check follows below).

Thanks so much! Hopefully the above isn't too much work to get a contribution in; we'd love to have your support.
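A minimal sketch of one such timing check, with save_fn standing in as an assumed placeholder for whatever checkpoint save call this PR uses; a save whose wall-clock time scales with the full model size rather than the per-host shard would hint at an unintended gather or all-reduce.

```python
import time

def timed_checkpoint_save(save_fn, state_dict, path):
    """Wrap a checkpoint save call with wall-clock timing.

    save_fn is a placeholder for the distributed checkpoint save used in the PR.
    """
    start = time.perf_counter()
    save_fn(state_dict, path)
    elapsed = time.perf_counter() - start
    print(f"checkpoint save to {path} took {elapsed:.1f}s")
    return elapsed
```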

@vlasenkoalexey
Collaborator

One more comment: please make sure that if checkpointing is not enabled, the PR has no effect on the way models are currently trained. We can relax this limitation later, but for now let's play it safe and make sure that the new functionality doesn't break anything.
