Support for lecun normal weight initialization (#2311)
* Support for lecun normal weight initialization

* add test

* add to docs

* Update utils.jl

* fixup

* fix rtol

* doctests

---------

Co-authored-by: Michael Abbott <[email protected]>
RohitRathore1 and mcabbott authored Oct 28, 2024
1 parent 32db5d4 commit 9f56f51
Showing 4 changed files with 51 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/src/reference/utilities.md
@@ -35,6 +35,7 @@ Flux.glorot_normal
 Flux.kaiming_uniform
 Flux.kaiming_normal
 Flux.truncated_normal
+Flux.lecun_normal
 Flux.orthogonal
 Flux.sparse_init
 Flux.identity_init
1 change: 1 addition & 0 deletions src/Flux.jl
@@ -119,6 +119,7 @@ export MPIBackend, NCCLBackend, DistributedUtils
   kaiming_uniform,
   kaiming_normal,
   truncated_normal,
+  lecun_normal,
   orthogonal,
   sparse_init,
   identity_init,
42 changes: 42 additions & 0 deletions src/utils.jl
@@ -248,6 +248,48 @@ truncated_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwa
 
 ChainRulesCore.@non_differentiable truncated_normal(::Any...)
 
+"""
+    lecun_normal([rng], size...) -> Array
+    lecun_normal([rng]; kw...) -> Function
+
+Return an `Array{Float32}` of the given `size` containing random numbers drawn from a truncated normal
+distribution centered on 0 with stddev `sqrt(1 / fan_in)`, where `fan_in` is the number of input units
+in the weight tensor.
+
+# Examples
+```jldoctest; setup = :(using Random; Random.seed!(0))
+julia> using Statistics
+
+julia> round(std(Flux.lecun_normal(10, 1000)), digits=3)
+0.032f0
+
+julia> round(std(Flux.lecun_normal(1000, 10)), digits=3)
+0.32f0
+
+julia> round(std(Flux.lecun_normal(1000, 1000)), digits=3)
+0.032f0
+
+julia> Dense(10 => 1000, selu; init = Flux.lecun_normal())
+Dense(10 => 1000, selu)  # 11_000 parameters
+
+julia> round(std(ans.weight), digits=3)
+0.313f0
+```
+
+# References
+
+[1] Lecun, Yann, et al. "Efficient backprop." Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 9-48.
+"""
+function lecun_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
+  std = Float32(gain)*sqrt(1.0f0 / first(nfan(dims...))) # calculates the standard deviation based on the `fan_in` value
+  return truncated_normal(rng, dims...; mean=0, std=std)
+end
+
+lecun_normal(dims::Integer...; kwargs...) = lecun_normal(default_rng(), dims...; kwargs...)
+lecun_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> lecun_normal(rng, dims...; init_kwargs..., kwargs...)
+
+ChainRulesCore.@non_differentiable lecun_normal(::Any...)
+
 """
     orthogonal([rng], size...; gain = 1) -> Array
     orthogonal([rng]; kw...) -> Function
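For context on the new initializer: it draws from a truncated normal with standard deviation `sqrt(1 / fan_in)`, where `fan_in` for a `Dense` weight of size `(out, in)` is the second dimension (the `first(nfan(dims...))` above). A minimal usage sketch of both calling styles, not part of the diff and assuming a Flux build that includes this commit:

```julia
using Flux, Statistics

# Direct call: weight of size (out = 1000, in = 10), so fan_in = 10.
W = Flux.lecun_normal(1000, 10)
std(W)  # ≈ sqrt(1/10) ≈ 0.316

# Zero-size call: returns an initializer closure (the partial-application
# method above), suitable for a layer's `init` keyword.
layer = Dense(10 => 1000, selu; init = Flux.lecun_normal())
std(layer.weight)  # again ≈ 0.316, since fan_in = 10 here too
```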
9 changes: 7 additions & 2 deletions test/utils.jl
@@ -1,6 +1,6 @@
 using Flux
 using Flux: throttle, nfan, glorot_uniform, glorot_normal,
-            kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
+            kaiming_normal, kaiming_uniform, orthogonal, truncated_normal, lecun_normal,
             sparse_init, identity_init, unstack, batch, unbatch,
             unsqueeze, params, loadmodel!
 using MLUtils
@@ -75,7 +75,7 @@ end
     kaiming_uniform, kaiming_normal,
     orthogonal,
     sparse_init,
-    truncated_normal,
+    truncated_normal, lecun_normal,
     identity_init,
     Flux.rand32,
     Flux.randn32,
@@ -192,6 +192,11 @@ end
   end
 end
 
+@testset "lecun_normal" begin
+  @test std(Flux.lecun_normal(10, 1000)) ≈ 0.032f0 rtol=0.1
+  @test std(Flux.lecun_normal(1000, 10)) ≈ 0.317f0 rtol=0.1
+end
+
 @testset "Partial application" begin
   partial_ku = kaiming_uniform(gain=1e9)
   @test maximum(partial_ku(8, 8)) > 1e9 / 2
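The `rtol=0.1` tolerance in the new tests accounts for sampling variability: the empirical standard deviation of a finite random matrix only approximates the target `sqrt(1 / fan_in)` (≈ 0.0316 and ≈ 0.3162 for these shapes). As background for the `selu` pairing shown in the docstring example: LeCun-normal initialization is what makes SELU networks self-normalizing, keeping activations near zero mean and unit variance across layers. A rough illustrative check, assuming the same Flux build (this is not part of the commit's test suite):

```julia
using Flux, Statistics

# Three selu layers with LeCun-normal weights: activations should stay
# close to mean 0 and std 1 as depth increases.
snn = Chain(
    Dense(128 => 128, selu; init = Flux.lecun_normal()),
    Dense(128 => 128, selu; init = Flux.lecun_normal()),
    Dense(128 => 128, selu; init = Flux.lecun_normal()),
)

x = randn(Float32, 128, 1024)  # unit-variance inputs, batch of 1024
y = snn(x)
(mean(y), std(y))              # expect values near (0, 1)
```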
