This repository has been archived by the owner on Nov 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from LuxDL/fm/start
[WIP] Adding initializers
- Loading branch information
Showing
10 changed files
with
402 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,5 +3,11 @@ uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" | |
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "0.1.0" | ||
|
||
[deps] | ||
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
|
||
[compat] | ||
julia = "1.6" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,3 +87,4 @@ plugins: | |
|
||
nav: | ||
- "WeightInitializers.jl": "index.md" | ||
- "API Reference": "api.md" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Weight Initializers | ||
|
||
```@docs | ||
zeros32 | ||
ones32 | ||
rand32 | ||
randn32 | ||
glorot_normal | ||
glorot_uniform | ||
kaiming_normal | ||
kaiming_uniform | ||
truncated_normal | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,13 @@ | ||
# Package entry point: pulls in the dependencies, loads the two source files and
# declares the public initializer names.
module WeightInitializers

using PartialFunctions, Random, SpecialFunctions, Statistics

include("utils.jl")
include("initializers.jl")

# Public API: constant fills, plain random fills, then the scaled schemes.
export zeros32, ones32, rand32, randn32,
       glorot_normal, glorot_uniform,
       kaiming_normal, kaiming_uniform,
       truncated_normal

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
""" | ||
zeros32(::AbstractRNG, size...) = zeros(Float32, size...) | ||
Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) | ||
""" | ||
function zeros32(::AbstractRNG, dims...)
    # The RNG argument is accepted but never read; only `dims` shapes the result.
    return zeros(Float32, dims...)
end
|
||
""" | ||
ones32(::AbstractRNG, size...) = ones(Float32, size...) | ||
Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored) | ||
""" | ||
function ones32(::AbstractRNG, dims...)
    # The RNG argument is accepted but never read; only `dims` shapes the result.
    return ones(Float32, dims...)
end
|
||
""" | ||
randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...) | ||
Return an `Array{Float32}` of random numbers from a standard normal distribution of the | ||
given `size`. | ||
""" | ||
function randn32(rng::AbstractRNG, dims...)
    # Standard-normal Float32 samples drawn from the caller-supplied generator.
    return randn(rng, Float32, dims...)
end
|
||
""" | ||
rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) | ||
Return an `Array{Float32}` of random numbers from a uniform distribution of the given | ||
`size`. | ||
""" | ||
function rand32(rng::AbstractRNG, dims...)
    # Uniform Float32 samples drawn from the caller-supplied generator.
    return rand(rng, Float32, dims...)
end
|
||
""" | ||
glorot_uniform(rng::AbstractRNG, size...; gain = 1) | ||
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a | ||
uniform distribution on the interval ``[-x, x]``, where | ||
`x = gain * sqrt(6 / (fan_in + fan_out))`. This method is described in [1] and also known as | ||
Xavier initialization. | ||
# References | ||
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep | ||
feedforward neural networks." _Proceedings of the thirteenth international conference on | ||
artificial intelligence and statistics_. 2010. | ||
""" | ||
function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
    fan_in, fan_out = _nfan(dims...)
    # Full width of the interval [-x, x] with x = gain * sqrt(6 / (fan_in + fan_out)):
    # 2x == gain * sqrt(24 / (fan_in + fan_out)).
    width = Float32(gain) * sqrt(24.0f0 / (fan_in + fan_out))
    # Centre the uniform samples on zero, then stretch them to the interval width.
    return width .* (rand(rng, Float32, dims...) .- 0.5f0)
end
|
||
""" | ||
glorot_normal(rng::AbstractRNG, size...; gain = 1) | ||
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal | ||
distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This method is | ||
described in [1] and also known as Xavier initialization. | ||
# References | ||
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep | ||
feedforward neural networks." _Proceedings of the thirteenth international conference on | ||
artificial intelligence and statistics_. 2010. | ||
""" | ||
function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
    fan_in, fan_out = _nfan(dims...)
    # Standard deviation gain * sqrt(2 / (fan_in + fan_out)), kept in Float32.
    σ = Float32(gain) * sqrt(2.0f0 / (fan_in + fan_out))
    return σ .* randn(rng, Float32, dims...)
end
|
||
""" | ||
kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) | ||
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a | ||
uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`. | ||
# References | ||
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on | ||
imagenet classification." _Proceedings of the IEEE international conference on computer | ||
vision_. 2015. | ||
""" | ||
function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0)
    fan_in = first(_nfan(dims...))
    # Half-width of the sampling interval: gain * sqrt(3 / fan_in), stored as Float32.
    bound = Float32(√3.0f0 * gain / sqrt(fan_in))
    # Centre the uniform samples on zero and stretch them to [-bound, bound]. The scalar
    # factor is grouped as `(2 * bound)` so the expression is one fused broadcast; the
    # previous `.* 2 * bound` ended in an undotted `*`, which made a second pass over
    # the data and allocated a second array.
    return (rand(rng, Float32, dims...) .- 0.5f0) .* (2 * bound)
end
|
||
""" | ||
kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) | ||
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
distribution with standard deviation `gain / sqrt(fan_in)`.
# References | ||
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on | ||
imagenet classification." _Proceedings of the IEEE international conference on computer | ||
vision_. 2015. | ||
""" | ||
function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0)
    fan_in, _ = _nfan(dims...)
    # Standard deviation gain / sqrt(fan_in), converted to Float32 once up front.
    σ = Float32(gain / sqrt(fan_in))
    return σ .* randn(rng, Float32, dims...)
end
|
||
""" | ||
truncated_normal([rng = _default_rng()], size...; mean = 0, std = 1, lo = -2, hi = 2)
Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution. | ||
The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`. | ||
""" | ||
function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo=-2, hi=2)
    # Warn (once) when the mean sits more than two standard deviations outside [lo, hi].
    if (mean < lo - 2 * std) || (mean > hi + 2 * std)
        @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
    end

    # Normal CDF evaluated at the standardised truncation limits.
    cdf_lo = _norm_cdf((lo - mean) / std)
    cdf_hi = _norm_cdf((hi - mean) / std)

    # Coefficients of the affine map applied to each uniform sample before `erfinv`;
    # they do not depend on the sample, so they are computed once.
    span = 2(cdf_hi - cdf_lo)
    offset = 2cdf_lo - 1

    samples = rand(rng, Float32, dims...)
    # Transform every uniform sample in place: affine map, inverse error function,
    # rescale by `std * √2`, shift by `mean`, and finally clamp into [lo, hi].
    broadcast!(samples, samples) do p
        z = erfinv(p * span + offset)
        return clamp(z * std * √2 + mean, lo, hi)
    end
    return samples
end
|
||
# Default fallbacks: every initializer listed below gains three convenience methods.
for initializer in (:zeros32, :ones32, :randn32, :rand32,
                    :glorot_uniform, :glorot_normal,
                    :kaiming_uniform, :kaiming_normal,
                    :truncated_normal)
    @eval begin
        # Only dimensions given: forward to the RNG method using `_default_rng()`.
        function ($initializer)(dims::Integer...; kwargs...)
            return $(initializer)(_default_rng(), dims...; kwargs...)
        end

        # Only an RNG (plus keywords) given: hand the initializer, the RNG and the
        # keywords to `_partial_apply` and return its result.
        function ($initializer)(rng::AbstractRNG; kwargs...)
            return _partial_apply($initializer, (rng, (; kwargs...)))
        end

        # Only keywords given: hand the initializer and the keywords to `_partial_apply`.
        function ($initializer)(; kwargs...)
            return _partial_apply($initializer, (; kwargs...))
        end
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Fan-in / fan-out helpers. Every method returns the tuple `(fan_in, fan_out)`.
@inline _nfan() = (1, 1)
# A vector is treated as an n×1 matrix.
@inline _nfan(n) = (1, n)
# Dense kernels are arranged as (n_out, n_in) matrices.
@inline _nfan(n_out, n_in) = (n_in, n_out)
# A tuple of dimensions is unpacked and dispatched on its length.
@inline _nfan(dims::Tuple) = _nfan(dims...)
# Convolution kernels: the product of all leading dimensions scales the last two.
@inline function _nfan(dims...)
    receptive = prod(dims[1:(end - 2)])
    return receptive * dims[end - 1], receptive * dims[end]
end
# Cumulative distribution function of the standard normal, 0.5 * (1 + erf(x / √2)).
_norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2))
# Integer types cannot represent 0.5 (`T(0.5)` throws `InexactError`), so integer
# arguments are promoted to floating point before evaluation.
_norm_cdf(x::Integer) = _norm_cdf(float(x))
|
||
# Build the generator used when a caller supplies no RNG. A new generator seeded with
# 1234 is constructed on every call; the concrete type depends on the Julia version.
function _default_rng()
    @static if VERSION < v"1.7"
        return MersenneTwister(1234)
    else
        return Xoshiro(1234)
    end
end
|
||
# Thin wrapper around `PartialFunctions.$`; needed because `$` cannot be written
# directly inside an `@eval` block, where it means interpolation.
function _partial_apply(fn, inp)
    return fn $ inp
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
f85f994
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
f85f994
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/86004
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: