Hello,
I am trying to understand gradient checkpointing and found your explanation in gradient-checkpointing-nin.ipynb very helpful. I cloned the repo and tried rerunning the experiments. However, I was unable to reproduce the result mentioned in your conclusion.
When I run the notebook, the vanilla NiN gives a memory consumption (current, peak) of 413527 and 154049604, with a runtime of 109.1 s.
For the checkpointed version of the model (segments=1), the memory consumption is 402938 and 154064699, with a runtime of 110.14 s.
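For reference, the (current, peak) pair above is what I read off Python's tracemalloc, measured roughly as in the sketch below. This is my own measurement harness, and `train_one_epoch`, `model`, and `loader` are placeholders for the notebook's actual training code, so it may not match the notebook's measurement exactly:

```python
# Minimal sketch of how I collect the (current, peak) memory and runtime numbers.
# Assumes CPU-side measurement with tracemalloc (values in bytes);
# train_one_epoch, model, and loader are hypothetical stand-ins for the notebook's code.
import time
import tracemalloc

tracemalloc.start()
start = time.time()

train_one_epoch(model, loader)  # placeholder for the notebook's training loop

current, peak = tracemalloc.get_traced_memory()  # returns (current, peak) in bytes
tracemalloc.stop()

print(f"current={current}, peak={peak}, runtime={time.time() - start:.2f}s")
```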
From these tests I was not able to observe the significant memory improvement the notebook reports (a 22% memory improvement at the cost of a 14% longer runtime); the peak memory is essentially unchanged between the two runs.
I've also tried multiple seeds and different checkpoint segment settings, and did not see a significant memory improvement there either.
I'm not sure why this is and could use a bit of help. Could it be that the network is relatively small, so the effect is less obvious? Or has PyTorch's checkpointing implementation changed over the years? I would appreciate any insight you could provide.
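In case I am simply misusing the API, here is roughly how I am invoking checkpointing for the segments=1 runs; a minimal sketch assuming torch.utils.checkpoint.checkpoint_sequential over an nn.Sequential model (the layers below are an illustrative stand-in, not the notebook's exact NiN):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Illustrative nn.Sequential stand-in; not the notebook's exact NiN architecture.
model = nn.Sequential(
    nn.Conv2d(3, 96, kernel_size=5, padding=2), nn.ReLU(),
    nn.Conv2d(96, 96, kernel_size=1), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(96, 10, kernel_size=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)

# Input requires grad so gradients flow back through the checkpointed segments.
x = torch.randn(64, 3, 32, 32, requires_grad=True)

segments = 1  # the setting used in the numbers above; I also tried larger values
out = checkpoint_sequential(model, segments, x)
out.sum().backward()
```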