-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix potential memory issues when use deepspeed Z3 #6726
Conversation
@microsoft-github-policy-service agree company=Intel |
1 similar comment
@microsoft-github-policy-service agree company=Intel |
@wenbinc-Bin, thanks for the PR. Are you able to provide unit tests for this? |
Ok, I try to add a unit test. |
"ds_grads_remaining" is used to triger post_backward_function(). If the module is called more than once in one training step, this variable will be initialized every time. If backward() is also called multiple times, "ds_grads_remaining" will be reduced to a negative number. post_backward_function() will not be called as expected. This leads to extra fetch operation or extra memory usage. Set "ds_grads_remaining" to 0 only when it is not initialized to fix this issue.` Signed-off-by: Wenbin Chen <[email protected]>
Fix a bug that after the first training step, allocated parameters may bigger than "__max_n_available_params". "__n_available_params" is set to 0 in reset_step() which is called in backward(). All parameter are released in release_and_reset_all() which is called in step(). "__n_available_params" is reduced when parameter is released. These mean if step() is called after backward(), "__n_available_params" will be reduced to a negative number. "__n_available_params" is used to restrict fetched parameters, so negative value leads to a problem that fetched parameter will be larger than upper bound ("__max_n_available_params"). Move "__n_available_params = 0" to release_and_reset_all() to fix this issue. Signed-off-by: Wenbin Chen <[email protected]>
If run model more than once in one training step, there may be issues. Add unit test to catch these kinds of problems. Signed-off-by: Wenbin Chen <[email protected]>
I saw unit-tests workflow failed. I fix the issue and update the PR. |
Signed-off-by: Wenbin Chen <[email protected]>
Head branch was pushed to by a user without write access
I had OOM problem when doing DPO training using zero3. It needs to call module twice in one training step, and second call is with no_grad().
The problem is caused by two bugs:
I tried to create two patches to fix these issues.