Fix problem with gradients accumulating #3
The problem arises because gradients accumulate in the layers with each new transformation.
Suppose the original model works well with batch_size = 128 (for example) and the GPU is fully loaded; the same model wrapped with several transforms will then crash with batch_size = 128 and will only work with a batch size of roughly 40.
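For illustration, here is a minimal sketch of what such a wrapper's forward pass might look like; `TTAWrapper` and its constructor arguments are hypothetical names, not the actual classes in this repo:

```python
import torch
import torch.nn as nn

class TTAWrapper(nn.Module):
    """Hypothetical stand-in for the model-plus-transforms wrapper."""

    def __init__(self, model, transforms):
        super().__init__()
        self.model = model
        self.transforms = transforms

    def forward(self, x):
        outputs = []
        for t in self.transforms:
            # Every call records its own autograd graph, so activation memory
            # grows roughly linearly with the number of transforms.
            outputs.append(self.model(t(x)))
        return torch.stack(outputs).mean(dim=0)
```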
To fix this problem I added `torch.no_grad()`.

P.S. If you already know about this problem and assume that this call should go in the outer function instead, then you need to change the example snippet in the README, because it doesn't include `no_grad()`. But in my opinion it is better to keep this code here, inside `forward`.
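Here is the same hypothetical sketch with the change applied, plus the outer-code alternative mentioned above (again, all names are illustrative, not the repo's actual API):

```python
import torch
import torch.nn as nn

class TTAWrapper(nn.Module):
    """Same hypothetical wrapper, with the proposed fix inside forward."""

    def __init__(self, model, transforms):
        super().__init__()
        self.model = model
        self.transforms = transforms

    def forward(self, x):
        outputs = []
        with torch.no_grad():  # no autograd graph is recorded, memory stays flat
            for t in self.transforms:
                outputs.append(self.model(t(x)))
        return torch.stack(outputs).mean(dim=0)


# The alternative discussed above: leave forward unchanged and wrap the call
# in the outer code, which the README snippet would then need to show:
#
#     with torch.no_grad():
#         predictions = wrapped_model(batch)
```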