
DEFAULT_MASK_VALUE causes gradient explosion and NaN loss on deep models #614

Open
logicchains opened this issue Apr 23, 2024 · 2 comments
Comments

@logicchains

I was training a Llama model on GPU, with a custom embedding. It worked fine with 12 layers, dim 1024, seq length 256, but the loss became NaN after the first step when num_layers was set to more than 17. I debugged the gradients and found that their magnitude grew by around 100x per layer, until they hit float32 max at around the 18th layer and became inf, leading to the NaN loss.
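
A minimal sketch of that kind of per-layer gradient check, assuming a recent JAX with jax.tree_util.tree_flatten_with_path and an ordinary pytree of gradients (the helper name and the dummy pytree below are illustrative, not MaxText code):

```python
import jax
import jax.numpy as jnp

def report_grad_magnitudes(grads):
    """Print the max |grad| per leaf so per-layer growth is easy to spot."""
    flat, _ = jax.tree_util.tree_flatten_with_path(grads)
    for path, leaf in flat:
        print(f"{jax.tree_util.keystr(path)}: max|grad| = {float(jnp.max(jnp.abs(leaf))):.3e}")

# Dummy gradients standing in for real per-layer values; in practice this
# would be the pytree returned by jax.grad on the training loss.
dummy_grads = {
    "layer_00": {"w": jnp.ones((4, 4)) * 1e0},
    "layer_01": {"w": jnp.ones((4, 4)) * 1e2},
    "layer_02": {"w": jnp.ones((4, 4)) * 1e4},
}
report_grad_magnitudes(dummy_grads)
```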

The gradient explosion seemed to be coming from
local_exps = jnp.exp(attn_weights - local_max)
in attentions.py.

Changing

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
to
DEFAULT_MASK_VALUE = -jnp.inf
fixed the issue, and the gradients' magnitude stopped increasing after each layer.
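
For anyone who wants to poke at this in isolation, here is a self-contained sketch that approximates (rather than reproduces) the masking-plus-softmax pattern in attentions.py with a simple jnp.where, and compares the gradients that flow through it under the two mask values. The function and variable names are just for illustration:

```python
import jax
import jax.numpy as jnp

FINITE_MASK = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
INF_MASK = -jnp.inf

def masked_softmax_sum(attn_weights, mask, mask_value):
    # Rough stand-in for the attentions.py pattern: mask the logits, subtract
    # the running max, exponentiate, normalize.
    masked = jnp.where(mask, attn_weights, mask_value)
    local_max = jnp.max(masked, axis=-1, keepdims=True)
    local_exps = jnp.exp(masked - local_max)
    probs = local_exps / jnp.sum(local_exps, axis=-1, keepdims=True)
    return jnp.sum(probs ** 2)  # arbitrary scalar so jax.grad applies

attn_weights = jax.random.normal(jax.random.PRNGKey(0), (4, 8), dtype=jnp.float32)
mask = jnp.tril(jnp.ones((4, 8), dtype=bool))  # causal-style mask

for name, mask_value in [("finite mask", FINITE_MASK), ("-inf mask", INF_MASK)]:
    grads = jax.grad(masked_softmax_sum)(attn_weights, mask, mask_value)
    print(name, "max|grad| =", float(jnp.max(jnp.abs(grads))),
          "any nan:", bool(jnp.any(jnp.isnan(grads))))
```

This toy won't necessarily reproduce the ~100x per-layer growth seen in the full model (that depends on the real attention kernel and its blocking); it only isolates the point where DEFAULT_MASK_VALUE enters the computation.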

Presumably the issue wasn't noticed during TPU training as that uses a separate codepath.

@rwitten
Collaborator

rwitten commented Apr 30, 2024

@logicchains thanks for the tips on GPU convergence! We will experiment with this as we set up convergence regimes for GPUs.

@anfals please be aware of this as you do convergence testing on GPU

@gobbleturk added the bug label on Sep 17, 2024
@shralex
Collaborator

shralex commented Sep 17, 2024

@anfals is this something you're still working on, or already fixed? Thanks!
