diff --git a/Project.toml b/Project.toml
index 8ff3b13..6bffc6f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,5 +3,11 @@ uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 authors = ["Avik Pal and contributors"]
 version = "0.1.0"
 
+[deps]
+PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+
 [compat]
 julia = "1.6"
diff --git a/README.md b/README.md
index 3e3e641..56db605 100644
--- a/README.md
+++ b/README.md
@@ -11,4 +11,71 @@
 [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
 [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
 
-`WeightInitializers.jl` provides common weight initialization schemes for deep learning models.
+This package is a light dependency providing common weight initialization schemes for deep learning models.
+
+## Example
+
+These code snippets give a high-level overview of the package's functionality.
+Please refer to the [stable documentation](https://luxdl.github.io/WeightInitializers.jl/stable) for more information
+about the package. The
+[under development documentation](https://luxdl.github.io/WeightInitializers.jl/dev)
+provides information on features not yet released.
+
+```julia
+using WeightInitializers, Random
+
+# Fixing rng
+rng = Random.MersenneTwister(42)
+
+# Explicit rng call
+weights = kaiming_normal(rng, 2, 5)
+#2×5 Matrix{Float32}:
+# -0.351662   0.0171745   1.12442   -0.296372   -1.67094
+# -0.281053  -0.18941    -0.724099   0.0987538   0.634549
+
+# Default rng call
+weights = kaiming_normal(2, 5)
+#2×5 Matrix{Float32}:
+# -0.227513  -0.265372   0.265788  1.29955  -0.192836
+#  0.687611   0.454679  -0.433656  0.20548   0.292002
+
+# Passing kwargs (if needed) with explicit rng call
+weights_cl = kaiming_normal(rng; gain=1.0)
+weights = weights_cl(rng, 2, 5)
+#2×5 Matrix{Float32}:
+# 0.484056   0.231723   0.164379   0.306147   0.18365
+# 0.0836414  0.666965  -0.396323  -0.711329  -0.382971
+
+# Passing kwargs (if needed) with default rng call
+weights_cl = kaiming_normal(; gain=1.0)
+weights = weights_cl(2, 5)
+#2×5 Matrix{Float32}:
+# -0.160876  -0.187646   0.18794   0.918918  -0.136356
+#  0.486214   0.321506  -0.306641  0.145296   0.206476
+```
+
+## API
+
+The package is meant to work with deep learning libraries such as Flux and Lux.
+All methods take the chosen `rng` and the dimensions of the array as input.
+
+```julia
+weights = init(rng, dims...)
+```
+
+The `rng` is optional; if not specified, a default one will be used.
+
+```julia
+weights = init(dims...)
+```
+
+If keyword arguments are needed, the methods can be called with just the `rng` (optionally)
+and the keyword arguments, returning a function that behaves like the two examples above.
+
+```julia
+weights_init = init(rng; kwargs...)
+weights = weights_init(rng, dims...)
+# or
+weights_init = init(; kwargs...)
+weights = weights_init(dims...)
+```
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 2ad45a6..77b6ad3 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -87,3 +87,4 @@ plugins:
 
 nav:
   - "WeightInitializers.jl": "index.md"
+  - "API Reference": "api.md"
diff --git a/docs/src/api.md b/docs/src/api.md
new file mode 100644
index 0000000..4016aa4
--- /dev/null
+++ b/docs/src/api.md
@@ -0,0 +1,13 @@
+# Weight Initializers
+
+```@docs
+zeros32
+ones32
+rand32
+randn32
+glorot_normal
+glorot_uniform
+kaiming_normal
+kaiming_uniform
+truncated_normal
+```
diff --git a/docs/src/index.md b/docs/src/index.md
index dc2fbb3..345f450 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -17,10 +17,59 @@ CurrentModule = WeightInitializers
 ```
 
-## API Reference
+```julia
+using WeightInitializers, Random
 
-### Index
+# Fixing rng
+rng = Random.MersenneTwister(42)
 
-```@index
-Pages = ["index.md"]
+# Explicit rng call
+weights = kaiming_normal(rng, 2, 5)
+#2×5 Matrix{Float32}:
+# -0.351662   0.0171745   1.12442   -0.296372   -1.67094
+# -0.281053  -0.18941    -0.724099   0.0987538   0.634549
+
+# Default rng call
+weights = kaiming_normal(2, 5)
+#2×5 Matrix{Float32}:
+# -0.227513  -0.265372   0.265788  1.29955  -0.192836
+#  0.687611   0.454679  -0.433656  0.20548   0.292002
+
+# Passing kwargs (if needed) with explicit rng call
+weights_cl = kaiming_normal(rng; gain=1.0)
+weights = weights_cl(rng, 2, 5)
+#2×5 Matrix{Float32}:
+# 0.484056   0.231723   0.164379   0.306147   0.18365
+# 0.0836414  0.666965  -0.396323  -0.711329  -0.382971
+
+# Passing kwargs (if needed) with default rng call
+weights_cl = kaiming_normal(; gain=1.0)
+weights = weights_cl(2, 5)
+#2×5 Matrix{Float32}:
+# -0.160876  -0.187646   0.18794   0.918918  -0.136356
+#  0.486214   0.321506  -0.306641  0.145296   0.206476
 ```
+
+## Quick examples
+
+The package is meant to work with deep learning libraries such as Flux and Lux.
+All methods take the chosen `rng` and the dimensions of the array as input.
+```julia
+weights = init(rng, dims...)
+```
+
+The `rng` is optional; if not specified, a default one will be used.
+```julia
+weights = init(dims...)
+```
+
+If keyword arguments are needed, the methods can be called with just the `rng` (optionally)
+and the keyword arguments, returning a function that behaves like the
+two examples above.
+```julia
+weights_init = init(rng; kwargs...)
+weights = weights_init(rng, dims...)
+# or
+weights_init = init(; kwargs...)
+weights = weights_init(dims...)
+```
\ No newline at end of file
diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl
index a771033..6d70386 100644
--- a/src/WeightInitializers.jl
+++ b/src/WeightInitializers.jl
@@ -1,3 +1,13 @@
 module WeightInitializers
 
+using PartialFunctions, Random, SpecialFunctions, Statistics
+
+include("utils.jl")
+include("initializers.jl")
+
+export zeros32, ones32, rand32, randn32
+export glorot_normal, glorot_uniform
+export kaiming_normal, kaiming_uniform
+export truncated_normal
+
 end
diff --git a/src/initializers.jl b/src/initializers.jl
new file mode 100644
index 0000000..b05c38c
--- /dev/null
+++ b/src/initializers.jl
@@ -0,0 +1,140 @@
+"""
+    zeros32(::AbstractRNG, size...) = zeros(Float32, size...)
+
+Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored)
+"""
+zeros32(::AbstractRNG, dims...) = zeros(Float32, dims...)
+
+"""
+    ones32(::AbstractRNG, size...) = ones(Float32, size...)
+
+Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored)
+"""
+ones32(::AbstractRNG, dims...) = ones(Float32, dims...)
+
+"""
+    randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...)
+
+Return an `Array{Float32}` of random numbers from a standard normal distribution of the
+given `size`.
+"""
+randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...)
+
+"""
+    rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...)
+
+Return an `Array{Float32}` of random numbers from a uniform distribution of the given
+`size`.
+"""
+rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...)
+
+"""
+    glorot_uniform(rng::AbstractRNG, size...; gain = 1)
+
+Return an `Array{Float32}` of the given `size` containing random numbers drawn from a
+uniform distribution on the interval ``[-x, x]``, where
+`x = gain * sqrt(6 / (fan_in + fan_out))`. This method is described in [1] and also known as
+Xavier initialization.
+
+# References
+
+[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep
+feedforward neural networks." _Proceedings of the thirteenth international conference on
+artificial intelligence and statistics_. 2010.
+"""
+function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
+    scale = Float32(gain) * sqrt(24.0f0 / sum(_nfan(dims...)))
+    return (rand(rng, Float32, dims...) .- 0.5f0) .* scale
+end
+
+"""
+    glorot_normal(rng::AbstractRNG, size...; gain = 1)
+
+Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
+distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This method is
+described in [1] and also known as Xavier initialization.
+
+# References
+
+[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep
+feedforward neural networks." _Proceedings of the thirteenth international conference on
+artificial intelligence and statistics_. 2010.
+"""
+function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
+    std = Float32(gain) * sqrt(2.0f0 / sum(_nfan(dims...)))
+    return randn(rng, Float32, dims...) .* std
+end
+
+"""
+    kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0)
+
+Return an `Array{Float32}` of the given `size` containing random numbers drawn from a
+uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`.
+
+# References
+
+[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on
+imagenet classification." _Proceedings of the IEEE international conference on computer
+vision_. 2015.
+"""
+function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0)
+    bound = Float32(√3.0f0 * gain / sqrt(first(_nfan(dims...))))
+    return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound
+end
+
+"""
+    kaiming_normal(rng::AbstractRNG, size...; gain = √2f0)
+
+Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
+distribution with standard deviation `gain / sqrt(fan_in)`.
+
+# References
+
+[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on
+imagenet classification." _Proceedings of the IEEE international conference on computer
+vision_. 2015.
+"""
+function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0)
+    std = Float32(gain / sqrt(first(_nfan(dims...))))
+    return randn(rng, Float32, dims...) .* std
+end
+
+"""
+    truncated_normal([rng = _default_rng()], size...; mean = 0, std = 1, lo = -2, hi = 2)
+
+Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
+The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`.
+""" +function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo=-2, hi=2) + if (mean < lo - 2 * std) || (mean > hi + 2 * std) + @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 + end + l = _norm_cdf((lo - mean) / std) + u = _norm_cdf((hi - mean) / std) + xs = rand(rng, Float32, dims...) + broadcast!(xs, xs) do x + x = x * 2(u - l) + (2l - 1) + x = erfinv(x) + return x = clamp(x * std * √2 + mean, lo, hi) + end + return xs +end + +# Default Fallbacks for all functions +for initializer in (:zeros32, + :ones32, + :randn32, + :rand32, + :glorot_uniform, + :glorot_normal, + :kaiming_uniform, + :kaiming_normal, + :truncated_normal) + @eval function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), dims...; kwargs...) + end + @eval function ($initializer)(rng::AbstractRNG; kwargs...) + return _partial_apply($initializer, (rng, (; kwargs...))) + end + @eval ($initializer)(; kwargs...) = _partial_apply($initializer, (; kwargs...)) +end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..b26253e --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,17 @@ +@inline _nfan() = 1, 1 # fan_in, fan_out +@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix +@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices +@inline _nfan(dims::Tuple) = _nfan(dims...) +@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels +_norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) + +function _default_rng() + @static if VERSION >= v"1.7" + return Xoshiro(1234) + else + return MersenneTwister(1234) + end +end + +# This is needed if using `PartialFunctions.$` inside @eval block +_partial_apply(fn, inp) = fn$inp diff --git a/test/Project.toml b/test/Project.toml index da83f97..95e58e3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,7 @@ [deps] +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/runtests.jl b/test/runtests.jl index 3417cb7..7120d1e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1 +1,91 @@ -using WeightInitializers, Test +using WeightInitializers, Test, SafeTestsets, StableRNGs, Statistics + +const rng = StableRNG(12345) + +@testset "WeightInitializers.jl Tests" begin + @testset "_nfan" begin + # Fallback + @test WeightInitializers._nfan() == (1, 1) + # Vector + @test WeightInitializers._nfan(4) == (1, 4) + # Matrix + @test WeightInitializers._nfan(4, 5) == (5, 4) + # Tuple + @test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6) + # Convolution + @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) + end + + @testset "Sizes and Types: $init" for init in [ + zeros32, + ones32, + rand32, + randn32, + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, + truncated_normal, + ] + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(init(rng, 4, 2)) == Float32 + @test eltype(init(4, 2)) == Float32 + end + + @testset "Closure: $init" for init in [ + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, + 
+        truncated_normal,
+    ]
+        cl = init(;)
+        # Sizes
+        @test size(cl(3)) == (3,)
+        @test size(cl(rng, 3)) == (3,)
+        @test size(cl(3, 4)) == (3, 4)
+        @test size(cl(rng, 3, 4)) == (3, 4)
+        @test size(cl(3, 4, 5)) == (3, 4, 5)
+        @test size(cl(rng, 3, 4, 5)) == (3, 4, 5)
+        # Type
+        @test eltype(cl(4, 2)) == Float32
+        @test eltype(cl(rng, 4, 2)) == Float32
+    end
+
+    @testset "kaiming" begin
+        # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)]
+        # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out)
+        for (n_in, n_out) in [(100, 100), (100, 400)]
+            v = kaiming_uniform(rng, n_in, n_out)
+            σ2 = sqrt(6 / n_out)
+            @test -1σ2 < minimum(v) < -0.9σ2
+            @test 0.9σ2 < maximum(v) < 1σ2
+
+            v = kaiming_normal(rng, n_in, n_out)
+            σ2 = sqrt(2 / n_out)
+            @test 0.9σ2 < std(v) < 1.1σ2
+        end
+        # Type
+        @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5)) == Float32
+        @test eltype(kaiming_normal(rng, 3, 4; gain=1.5)) == Float32
+    end
+
+    @testset "glorot: $init" for init in [glorot_uniform, glorot_normal]
+        # glorot_uniform and glorot_normal should both yield a kernel with
+        # variance ≈ 2/(fan_in + fan_out)
+        for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
+            v = init(dims...)
+            fan_in, fan_out = WeightInitializers._nfan(dims...)
+            σ2 = 2 / (fan_in + fan_out)
+            @test 0.9σ2 < var(v) < 1.1σ2
+        end
+        @test eltype(init(3, 4; gain=1.5)) == Float32
+    end
+end
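
A quick sketch of the calling conventions described in the README and `docs/src/index.md` additions above (explicit `rng`, default `rng`, keyword closure), checking the standard deviation documented for `kaiming_normal`. This is not part of the patch; the dimensions and tolerance are arbitrary illustration choices.

```julia
using WeightInitializers, Random, Statistics

rng = Random.MersenneTwister(42)

# Explicit rng call; for a Dense kernel the dims are (fan_out, fan_in),
# so fan_in is the second dimension (matching `_nfan(n_out, n_in)`).
W = kaiming_normal(rng, 128, 512)
@assert size(W) == (128, 512) && eltype(W) == Float32
@assert isapprox(std(W), sqrt(2 / 512); rtol=0.1)  # documented std: gain / sqrt(fan_in), gain = √2

# Keyword closure; calling it without an rng falls back to the default rng.
init = kaiming_normal(; gain=1.0)
W2 = init(128, 512)
@assert isapprox(std(W2), 1 / sqrt(512); rtol=0.1)
```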
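
The `truncated_normal` implementation in `src/initializers.jl` uses inverse-CDF (inverse transform) sampling: uniform draws are mapped into the CDF interval `(l, u)` and then inverted with `erfinv`. The standalone sketch below re-derives the same transform for intuition only; the helper names `norm_cdf` and `truncated_normal_demo` are illustrative and not part of the package.

```julia
using Random, SpecialFunctions

# Standard normal CDF, matching `_norm_cdf` in src/utils.jl.
norm_cdf(x) = 0.5 * (1 + erf(x / √2))

# Inverse-CDF sampling of a normal truncated to [lo, hi]:
# map uniform draws into the CDF range (l, u), then invert via erfinv.
function truncated_normal_demo(rng, n; mean=0.0, std=1.0, lo=-2.0, hi=2.0)
    l = norm_cdf((lo - mean) / std)
    u = norm_cdf((hi - mean) / std)
    xs = rand(rng, n)                                    # uniform on [0, 1)
    ys = @. erfinv(2 * (l + xs * (u - l)) - 1) * std * √2 + mean
    return clamp.(ys, lo, hi)                            # guard the interval endpoints
end

samples = truncated_normal_demo(MersenneTwister(0), 10_000)
@assert all(x -> -2 <= x <= 2, samples)
```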