
Fix gradient clipping #5150

Merged
merged 2 commits into from Feb 21, 2024
Conversation

tohtana (Contributor) commented Feb 19, 2024

The gradient clipping API doesn't apply the coefficient correctly. This PR resolves the issue and adds a test case.
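
For context, here is a minimal sketch of what global-norm gradient clipping is expected to do, written as a hypothetical standalone helper rather than DeepSpeed's actual implementation. The essential step is multiplying every gradient by the clipping coefficient; computing the coefficient without applying it is the class of bug this PR fixes:

```python
import torch

def clip_grad_norm_(parameters, max_norm, eps=1e-6):
    """Hypothetical sketch: clip gradients in place to a global L2 norm."""
    grads = [p.grad for p in parameters if p.grad is not None]
    # Global L2 norm over all parameter gradients.
    total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
    # Coefficient that rescales gradients so the global norm equals max_norm.
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1:
        for g in grads:
            g.mul_(clip_coef)  # applying the coefficient is the step that must not be skipped
    return total_norm
```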

tohtana marked this pull request as ready for review February 19, 2024 06:17
tjruwase added this pull request to the merge queue Feb 21, 2024
Merged via the queue into master with commit 005afe1 Feb 21, 2024
12 checks passed
cloneofsimo commented Feb 24, 2024

I'm actually fascinated that this bug went unnoticed for so long. Isn't grad clipping on by default? Does this mean that literally every HuggingFace Trainer, Lightning, etc. setup that leverages the DeepSpeed backend was faulty, yet somehow everyone still trained successfully? How? Is it just not as critical as it looks?

Just to be clear, I'm a huge fan and user of DeepSpeed, and I'm very glad this tool exists. I'm just genuinely curious how this could have had so little impact (in terms of training dynamics or so) for so long, given that it looks like it affects all of engine.step(), and all of my previous experience with DeepSpeed has been very good. Sorry if I sounded too passive-aggressive.

tohtana (Contributor, Author) commented Feb 24, 2024

@cloneofsimo One reason is that this function is only called in limited cases. I noticed the issue when I set zero_stage=0 and precision=fp32; it didn't appear when I changed the ZeRO stage. Probably not many users have run DeepSpeed with this config.
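
For illustration, a minimal setup that exercises this path might look like the following. The config keys are real DeepSpeed options, but the specific values and the toy model are assumptions, not taken from the PR:

```python
import torch
import deepspeed

# Assumed minimal config for the affected path: ZeRO disabled (stage 0) and
# no fp16/bf16 section, i.e. plain fp32 training, with clipping enabled.
ds_config = {
    "train_batch_size": 8,
    "gradient_clipping": 1.0,           # routes through the clipping API fixed here
    "zero_optimization": {"stage": 0},  # ZeRO off; other stages use different code paths
}

model = torch.nn.Linear(16, 16)         # stand-in model for illustration
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```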

tohtana deleted the tohtana/fix_fp32_clipping branch February 24, 2024 06:58
cloneofsimo commented Feb 24, 2024

@tohtana thanks for the clarification!

Edit: yeah @SeunghyunSEO, @tohtana is right. It looks like other functions handle clipping under ZeRO, which is what people use DeepSpeed for in the first place.

SeunghyunSEO commented

@tohtana thanks for your kind explanation. My understanding is that when ZeRO stage 3 (or 1, 2) is enabled there is no problem, because a different function is used to clip and scale the gradients, right?

ShellyNR pushed a commit to ShellyNR/DeepSpeed that referenced this pull request Mar 11, 2024
The gradient clipping API doesn't apply the coefficient correctly. This
PR resolves the issue and adds a test case.

Co-authored-by: Logan Adams <[email protected]>
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
The gradient clipping API doesn't apply the coefficient correctly. This
PR resolves the issue and adds a test case.

Co-authored-by: Logan Adams <[email protected]>