Fix save_state_dict #645

Merged: 1 commit, Dec 12, 2023


AMHermansen (Collaborator)

The current implementation moves the entire model to the CPU whenever save_state_dict is called. This seems like an undesirable side effect of the method. This PR changes save_state_dict so that it only saves the state dict and keeps the model on its current device.

RasmusOrsoe (Collaborator)

@AMHermansen thanks for this suggestion. I'm not sure we want this change, though. If the state dict is saved to disk from the GPU, the model will have to be on the GPU when the state dict is loaded again, or an error will be thrown. See https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html.

Has this been an issue for you?
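
For context, the linked PyTorch recipe handles the cross-device case at load time with the `map_location` argument of `torch.load`. The sketch below illustrates that mechanism under some assumptions: an in-memory buffer stands in for a checkpoint file, `nn.Linear` stands in for a real model, and none of the names come from graphnet.

```python
import io

import torch
from torch import nn

# Assumption: nn.Linear is a stand-in for a real model.
# On a machine with a GPU the tensors are saved from CUDA, otherwise from CPU.
source_device = "cuda" if torch.cuda.is_available() else "cpu"
trained = nn.Linear(4, 2).to(source_device)

# An in-memory buffer stands in for a checkpoint file on disk.
checkpoint = io.BytesIO()
torch.save(trained.state_dict(), checkpoint)
checkpoint.seek(0)

# map_location remaps every stored tensor to the CPU while loading,
# so the checkpoint can be read on a machine without a GPU.
state = torch.load(checkpoint, map_location=torch.device("cpu"))

restored = nn.Linear(4, 2)  # a fresh model that lives on the CPU
restored.load_state_dict(state)
```

Without `map_location`, loading tensors that were saved from CUDA on a machine where CUDA is unavailable raises an error, which is the failure mode described above.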

AMHermansen (Collaborator, Author)

The change in this PR only removes a side effect of the current save_state_dict implementation: the model is no longer moved to the CPU when the method is called. Instead, the state_dict is copied to the CPU and the copy is saved. The goal is to make saving models more streamlined. My understanding from the example scripts is that save_model_config and save_state_dict are the intended way to save graphnet models. With the current implementation, however, saving a model this way during training causes problems, because the model is moved off the accelerator.
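
The idea of copying the state dict to the CPU and saving the copy can be sketched in plain PyTorch as follows. This is an illustration under stated assumptions, not the graphnet implementation: `save_state_dict_on_cpu` is a hypothetical helper name, `nn.Linear` stands in for a graphnet model, and an in-memory buffer stands in for the output file.

```python
import io

import torch
from torch import nn


def save_state_dict_on_cpu(model: nn.Module, destination) -> None:
    """Hypothetical helper: save CPU copies of the state dict tensors.

    The model itself is never moved, so training can continue on
    whatever device the model is currently on.
    """
    cpu_state = {
        key: value.cpu() if torch.is_tensor(value) else value
        for key, value in model.state_dict().items()
    }
    torch.save(cpu_state, destination)


# Assumption: nn.Linear is a stand-in for a real model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 2).to(device)

device_before = next(model.parameters()).device
buffer = io.BytesIO()  # stands in for a path such as a .pth file
save_state_dict_on_cpu(model, buffer)
device_after = next(model.parameters()).device
```

Calling `model.cpu()` before saving would instead move the module's parameters in place, which is the side effect discussed here.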

RasmusOrsoe (Collaborator) left a comment

@AMHermansen sorry, I glanced over this too quickly. I was under the impression that the state dict was saved on whatever device it happened to be on. Looking at the code again, I see that's not the case.

AMHermansen merged commit d06882b into graphnet-team:main on Dec 12, 2023
12 checks passed