Load state dicts to CPU #328

Merged: 3 commits merged into main from petew/load-state-dict on Oct 12, 2023
Conversation

epwalsh (Member) commented Oct 12, 2023

It turns out we can load (legacy) sharded and unsharded checkpoints to CPU via torch.load() since FSDP.load_state_dict() will copy tensors to the right device anyway.
This should save some GPU memory.
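
In other words, a minimal sketch of the pattern this PR relies on (illustrative only, with placeholder names like fsdp_model and checkpoint_path rather than the trainer's actual code):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def load_unsharded_checkpoint(fsdp_model: FSDP, checkpoint_path: str) -> None:
    """Illustrative helper: load a checkpoint onto CPU and let FSDP place tensors."""
    # map_location="cpu" keeps the full state dict in host memory instead of
    # materializing a second copy of the weights on GPU.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # FSDP's load_state_dict() copies each tensor to the destination module's
    # device as it loads, so no explicit .cuda() call is needed here.
    fsdp_model.load_state_dict(state_dict)
```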

load_fsdp_optim_state(self.fsdp_model, self.optim, optim_state_dict)

# Load other state.
try:
-    train_state_dict = torch.load(resource_path(load_path, "train.pt", local_cache=local_cache))
+    train_state_dict = load_state_dict(load_path, "train.pt", local_cache=local_cache)
Contributor

why no map_location="cpu" here?

Member Author

There's only one tiny tensor in the trainer state: the GPU RNG state, which needs to go on GPU anyway.
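
Roughly what that reply refers to (a sketch with assumed key names, not the repo's actual train.pt schema):

```python
import torch

# Illustrative only: the key names are assumptions. The point is that the
# trainer state is tiny, and the CUDA RNG state is restored through
# torch.cuda.set_rng_state() regardless of how the dict was loaded.
train_state = torch.load("train.pt")

torch.set_rng_state(train_state["rng_state"])            # CPU RNG state
torch.cuda.set_rng_state(train_state["cuda_rng_state"])  # CUDA RNG state
```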

@@ -218,7 +218,13 @@ def save_state_dict(
upload(target_path, upload_target, save_overwrite=save_overwrite)
Contributor

does this need a unit test?

Member Author

Our trainer unit tests are severely lacking at the moment, but I just tested this on spare MosaicML nodes and we recover exactly the same loss after these changes:
[screenshot: training loss matches exactly after the change]
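
If a unit test were added later, a minimal round-trip check might look something like this (a sketch only; it uses plain torch.save/torch.load rather than the repo's save_state_dict/load_state_dict helpers, and the test name and keys are illustrative):

```python
import torch


def test_state_dict_round_trip(tmp_path):
    """Check that what gets saved comes back bit-for-bit after a reload.

    Stands in for the heavier end-to-end check done on the MosaicML nodes
    (training resumes with exactly the same loss after the change).
    """
    state = {"step": 123, "loss": torch.tensor(2.5), "weights": torch.randn(4, 4)}
    path = tmp_path / "train.pt"

    torch.save(state, path)
    restored = torch.load(path, map_location="cpu")

    assert restored["step"] == state["step"]
    torch.testing.assert_close(restored["loss"], state["loss"])
    torch.testing.assert_close(restored["weights"], state["weights"])
```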

epwalsh (Member Author) commented Oct 12, 2023

@ibeltagy if this doesn't avoid OOM with our LUMI runs, we will have to set --fsdp.wrapping_strategy=by_block.
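
For reference, a by-block wrapping strategy roughly corresponds to an FSDP auto-wrap policy over the model's transformer block class (a sketch; nn.TransformerEncoderLayer is a stand-in for OLMo's actual block class):

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# "By block" wrapping: each transformer block becomes its own FSDP unit, so
# only one block's full parameters are gathered at a time, which lowers peak
# GPU memory compared to wrapping the whole model as a single unit.
by_block_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},  # placeholder block class
)

# Usage (needs an initialized process group):
#   fsdp_model = FullyShardedDataParallel(model, auto_wrap_policy=by_block_policy)
```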

epwalsh merged commit 809fe9d into main on Oct 12, 2023 (10 checks passed)
epwalsh deleted the petew/load-state-dict branch on October 12, 2023 at 17:11