
Epsilon change in normalise for stability #2421

Open
wants to merge 2 commits into base: master

Conversation


@billera billera commented Apr 7, 2024

The normalise function accepts an optional epsilon argument intended to improve numerical stability. Previously, epsilon was added after computing the standard deviation of the input. The standard deviation computation involves a square root, which leads to NaN's in gradients that depend on normalise when the variance is very low; for instance, a LayerNorm applied to a low-variance input yields NaN gradients. By computing the variance first and taking the square root only after adding epsilon^2 (squared to preserve scale), we prevent NaN's in gradients at low variance. See the following example with LayerNorm under the current implementation.

using Flux
using Zygote

ln = LayerNorm(256; eps = 1f-3)
for i in 1:10
    # Input with progressively smaller variance around 1.
    x = ones(Float32, 256) .+ randn(Float32, 256) .* 10f0^(-i)
    # Jacobian of the LayerNorm output with respect to x.
    l, gs = Zygote.withjacobian(ln, x)
    @show maximum(gs[1])
end


>>> maximum(gs[1]) = 9.44178f0
>>> maximum(gs[1]) = 95.85736f0
>>> maximum(gs[1]) = 477.4946f0
>>> maximum(gs[1]) = 910.05457f0
>>> maximum(gs[1]) = 985.8402f0
>>> maximum(gs[1]) = 995.0282f0
>>> maximum(gs[1]) = 995.9835f0
>>> maximum(gs[1]) = NaN32
>>> maximum(gs[1]) = NaN32
>>> maximum(gs[1]) = NaN32

We observe that the epsilon added to the denominator does cap the gradients at low variance, but it does not prevent NaN's: the square root inside the std computation is unpadded, and its derivative blows up as the variance approaches zero, so the chain rule produces NaN. With the updated normalise, these NaN's disappear,

>>> maximum(gs[1]) = 9.531697f0
>>> maximum(gs[1]) = 105.468056f0
>>> maximum(gs[1]) = 674.7051f0
>>> maximum(gs[1]) = 991.67163f0
>>> maximum(gs[1]) = 996.03973f0
>>> maximum(gs[1]) = 996.09314f0
>>> maximum(gs[1]) = 996.0937f0
>>> maximum(gs[1]) = 996.0937f0
>>> maximum(gs[1]) = 996.0937f0
>>> maximum(gs[1]) = 996.0937f0

and the gradients remain fixed at the implicitly capped value. A simple test verifying this computation's equivalence with the previous one (modulo the differences at very low standard deviations) could be added if desired.
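
For reference, the change amounts to moving epsilon inside the square root. A rough sketch of the two variants (simplified relative to the actual Flux source; keyword names and dims handling here are illustrative only):

using Statistics

# Previous behaviour: eps is added after the square root, so the gradient of
# std itself can be NaN when the sample variance is (numerically) zero.
normalise_old(x; dims=1, eps=1f-5) =
    (x .- mean(x; dims)) ./ (std(x; dims, corrected=false) .+ eps)

# Proposed behaviour: add eps^2 to the variance, then take the square root,
# keeping the denominator and its gradient finite at zero variance.
normalise_new(x; dims=1, eps=1f-5) =
    (x .- mean(x; dims)) ./ sqrt.(var(x; dims, corrected=false) .+ eps^2)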

@billera billera closed this Apr 7, 2024
@billera billera reopened this Apr 7, 2024

codecov bot commented Apr 7, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 74.02%. Comparing base (caa1cee) to head (b600f7a).

Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2421       +/-   ##
===========================================
+ Coverage   46.13%   74.02%   +27.88%     
===========================================
  Files          32       32               
  Lines        1877     1925       +48     
===========================================
+ Hits          866     1425      +559     
+ Misses       1011      500      -511     


@CarloLucibello
Member

I agree with this change, and PyTorch does the same thing. It should be considered a breaking change though, so let's wait until we are near v0.15 before merging.

@CarloLucibello CarloLucibello added this to the v0.15 milestone Apr 7, 2024
Co-authored-by: Carlo Lucibello <[email protected]>
@mcabbott
Member

mcabbott commented Apr 9, 2024

Can this have a test with input which triggers the NaN behaviour before?

Ideally testing not just the function, but also LayerNorm, maybe BatchNorm, anything which uses this internally. Then if the implementation of these layers finally gets replaced, it will be harder to lose the change.
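
A sketch of the kind of test requested, assuming the eps keyword shown in the example above (the exact input scale and the set of layers covered would be chosen by the author):

using Test, Flux, Zygote

@testset "normalise gradients stay finite at low variance" begin
    # Nearly constant input: in Float32 the perturbation is below eps(1f0),
    # so the sample variance is exactly zero and previously gave NaN gradients.
    x = ones(Float32, 256) .+ randn(Float32, 256) .* 1f-8

    ln = LayerNorm(256; eps = 1f-3)
    _, gs = Zygote.withjacobian(ln, x)
    @test all(isfinite, gs[1])

    # Also exercise the function directly, via a pullback through a scalar loss.
    _, back = Zygote.pullback(v -> sum(Flux.normalise(v; dims = 1)), x)
    @test all(isfinite, back(1f0)[1])
end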

@ToucheSir
Member

Putting a backlink to #2096 because this work should close that.
