From 473338847a6aad29fdbea481334b94f4f1d20997 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 8 Jun 2023 16:58:13 +0200 Subject: [PATCH 01/11] F/Lux initializers --- Project.toml | 3 + src/WeightInitializers.jl | 5 ++ src/inits.jl | 139 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+) create mode 100644 src/inits.jl diff --git a/Project.toml b/Project.toml index 8ff3b13..e958eb3 100644 --- a/Project.toml +++ b/Project.toml @@ -3,5 +3,8 @@ uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] version = "0.1.0" +[deps] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + [compat] julia = "1.6" diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index a771033..120bb1e 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,3 +1,8 @@ module WeightInitializers +using Random +include("inits.jl") +export zeros32, ones32, rand32, randn32 +export glorot_normal, glorot_uniform +export kaiming_normal, kaiming_uniform end diff --git a/src/inits.jl b/src/inits.jl new file mode 100644 index 0000000..10798fc --- /dev/null +++ b/src/inits.jl @@ -0,0 +1,139 @@ + +@inline _nfan() = 1, 1 # fan_in, fan_out +@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix +@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices +@inline _nfan(dims::Tuple) = _nfan(dims...) + +function _default_rng() + @static if VERSION >= v"1.7" + return Xoshiro(1234) + else + return MersenneTwister(1234) + end +end + +""" + default_rng_value() + +Create an instance of the default RNG depending on Julia's version. + - Julia version is < 1.7: `MersenneTwister(1234)` + - Julia version is >= 1.7: `Xoshiro(1234)` +""" +_default_rng + +""" + zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) + +Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) +""" +zeros32(rng::AbstractRNG, dims...) = zeros(rng, Float32, dims...) +zeros32(dims...) = zeros32(_default_rng(), dims...) +Base.zeros(rng::AbstractRNG, args...) = zeros(args...) +""" + ones32(rng::AbstractRNG, size...) = ones(Float32, size...) + +Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored) +""" +ones32(rng::AbstractRNG, dims...) = ones(rng, Float32, dims...) +ones32(dims...) = ones32(_default_rng(), dims...) +Base.ones(rng::AbstractRNG, dims...) = ones(dims...) + +""" + randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...) + +Return an `Array{Float32}` of random numbers from a standard normal distribution of the +given `size`. +""" +randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) +randn32(dims...) = randn32(_default_rng(), dims...) + +""" + rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) + +Return an `Array{Float32}` of random numbers from a uniform distribution of the given +`size`. +""" +rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) +rand32(dims...) = rand32(_default_rng(), dims...) + +""" + glorot_uniform(rng::AbstractRNG, size...; gain = 1) + +Return an `Array{Float32}` of the given `size` containing random numbers drawn from a +uniform distribution on the interval ``[-x, x]``, where +`x = gain * sqrt(6 / (fan_in + fan_out))`. This method is described in [1] and also known as +Xavier initialization. + +# References + +[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep +feedforward neural networks." 
_Proceedings of the thirteenth international conference on +artificial intelligence and statistics_. 2010. +""" +function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) + scale = Float32(gain) * sqrt(24.0f0 / sum(_nfan(dims...))) + return (rand(rng, Float32, dims...) .- 0.5f0) .* scale +end +glorot_uniform(dims::Integer...; kw...) = glorot_uniform(_default_rng(), dims...; kwargs...) +glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) + +""" + glorot_normal(rng::AbstractRNG, size...; gain = 1) + +Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal +distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This method is +described in [1] and also known as Xavier initialization. + +# References + +[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep +feedforward neural networks." _Proceedings of the thirteenth international conference on +artificial intelligence and statistics_. 2010. +""" +function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) + std = Float32(gain) * sqrt(2.0f0 / sum(_nfan(dims...))) + return randn(rng, Float32, dims...) .* std +end +glorot_normal(dims::Integer...; kwargs...) = glorot_normal(_default_rng(), dims...; kwargs...) +glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) + + +""" + kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) + +Return an `Array{Float32}` of the given `size` containing random numbers drawn from a +uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`. + +# References + +[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on +imagenet classification." _Proceedings of the IEEE international conference on computer +vision_. 2015. +""" +function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) + bound = Float32(√3.0f0 * gain / sqrt(first(_nfan(dims...)))) + return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound +end +kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(_default_rng(), dims...; kwargs...) +kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) + + +""" + kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) + +Return an `Array{Float32}` of the given `size` containing random numbers taken from a normal +distribution standard deviation `gain / sqrt(fan_in)` + +# References + +[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on +imagenet classification." _Proceedings of the IEEE international conference on computer +vision_. 2015. +""" +function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) + std = Float32(gain / sqrt(first(_nfan(dims...)))) + return randn(rng, Float32, dims...) .* std +end + +kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(_default_rng(), dims...; kwargs...) +kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) 
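
As a quick illustration of the initializers introduced in this patch, the snippet below exercises both call paths (a minimal sketch against the PATCH 01 state of the code; the seed, sizes, and the ±0.25 bound are illustrative, the bound following from `scale = gain * sqrt(24 / (fan_in + fan_out))` above):

```julia
using WeightInitializers, Random

rng = MersenneTwister(0)

# Dense-style kernel: for dims = (n_out, n_in) = (64, 32), `_nfan` yields
# (fan_in, fan_out) = (32, 64), so glorot_uniform samples lie in
# ±gain * sqrt(6 / (fan_in + fan_out)) = ±0.25 here.
W = glorot_uniform(rng, 64, 32)
@assert size(W) == (64, 32) && eltype(W) == Float32
@assert all(abs.(W) .<= 0.25f0)

# kaiming_normal scales a standard normal by gain / sqrt(fan_in);
# omitting the rng falls back to `_default_rng()`.
K = kaiming_normal(3, 3)
@assert eltype(K) == Float32
```

Note that, at this point in the series, `glorot_uniform(dims::Integer...; kw...)` forwards `kwargs...`, which is not defined in that method, so the rng-less `glorot_uniform(3, 4)` path raises an `UndefVarError`; the keyword name is made consistent in PATCH 04 below.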
From cff7354b64ecd101eac85bde997a7fa617a5d00a Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 8 Jun 2023 17:03:49 +0200 Subject: [PATCH 02/11] small changes --- src/inits.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/inits.jl b/src/inits.jl index 10798fc..ee6c1d1 100644 --- a/src/inits.jl +++ b/src/inits.jl @@ -11,15 +11,7 @@ function _default_rng() return MersenneTwister(1234) end end - -""" - default_rng_value() - -Create an instance of the default RNG depending on Julia's version. - - Julia version is < 1.7: `MersenneTwister(1234)` - - Julia version is >= 1.7: `Xoshiro(1234)` -""" -_default_rng + """ zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) From 6619885542fcbb71478b882568ca0dc3decdab87 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 9 Jun 2023 14:32:41 +0200 Subject: [PATCH 03/11] sketch for tests --- src/inits.jl | 31 +++++++++++++++++-------- test/Project.toml | 2 ++ test/runtests.jl | 58 ++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 80 insertions(+), 11 deletions(-) diff --git a/src/inits.jl b/src/inits.jl index ee6c1d1..6965186 100644 --- a/src/inits.jl +++ b/src/inits.jl @@ -12,7 +12,6 @@ function _default_rng() end end - """ zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) @@ -67,7 +66,9 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) return (rand(rng, Float32, dims...) .- 0.5f0) .* scale end glorot_uniform(dims::Integer...; kw...) = glorot_uniform(_default_rng(), dims...; kwargs...) -glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) +function glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) + return (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) +end """ glorot_normal(rng::AbstractRNG, size...; gain = 1) @@ -86,9 +87,12 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) std = Float32(gain) * sqrt(2.0f0 / sum(_nfan(dims...))) return randn(rng, Float32, dims...) .* std end -glorot_normal(dims::Integer...; kwargs...) = glorot_normal(_default_rng(), dims...; kwargs...) -glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) - +function glorot_normal(dims::Integer...; kwargs...) + return glorot_normal(_default_rng(), dims...; kwargs...) +end +function glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) + return (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) +end """ kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) @@ -106,9 +110,12 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0 bound = Float32(√3.0f0 * gain / sqrt(first(_nfan(dims...)))) return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound end -kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(_default_rng(), dims...; kwargs...) -kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) - +function kaiming_uniform(dims::Integer...; kwargs...) + return kaiming_uniform(_default_rng(), dims...; kwargs...) +end +function kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) + return (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) 
+end """ kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) @@ -127,5 +134,9 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) return randn(rng, Float32, dims...) .* std end -kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(_default_rng(), dims...; kwargs...) -kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) +function kaiming_normal(dims::Integer...; kwargs...) + return kaiming_normal(_default_rng(), dims...; kwargs...) +end +function kaiming_normal(rng::AbstractRNG; init_kwargs...) + return (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) +end diff --git a/test/Project.toml b/test/Project.toml index da83f97..aa8310d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,6 @@ [deps] +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/runtests.jl b/test/runtests.jl index 3417cb7..70bc9a1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1 +1,57 @@ -using WeightInitializers, Test +using WeightInitializers, Test, SafeTestsets, StableRNGs + +const rng = StableRNG(12345) + +@testset "inits: $init" for init in [ + zeros32, + ones32, + rand32, + randn32, + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, +] + #sizes + @test size(init(3)) == (3,) + @test size(rng, init(3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + #type + @test eltype(init(rng, 4, 2)) == Float32 + @test eltype(init(4, 2)) == Float32 + #closure #TODO @MartinuzzFrancesco + cl = init(rng) +end + +@testset "kaiming" begin + # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] + # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = kaiming_uniform(rng, n_in, n_out) + σ2 = sqrt(6 / n_out) + @test -1σ2 < minimum(v) < -0.9σ2 + @test 0.9σ2 < maximum(v) < 1σ2 + + v = kaiming_normal(rng, n_in, n_out) + σ2 = sqrt(2 / n_out) + @test 0.9σ2 < std(v) < 1.1σ2 + end + # + @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5)) == Float32 + @test eltype(kaiming_normal(rng, 3, 4; gain=1.5)) == Float32 +end + +@testset "glorot: $init" for init in [glorot_uniform, glorot_normal] + # glorot_uniform and glorot_normal should both yield a kernel with + # variance ≈ 2/(fan_in + fan_out) + for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + v = init(dims...) + fan_in, fan_out = nfan(dims...) 
+ σ2 = 2 / (fan_in + fan_out) + @test 0.9σ2 < var(v) < 1.1σ2 + end + @test eltype(init(3, 4; gain=1.5)) == Float32 +end From 6127455727541cb9b622227e4ca090410bde1f91 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 10 Jun 2023 15:17:35 +0200 Subject: [PATCH 04/11] more tests --- Project.toml | 1 + src/WeightInitializers.jl | 3 +++ src/inits.jl | 17 ++++++++++++++--- test/Project.toml | 1 + test/runtests.jl | 29 ++++++++++++++++++++++------- 5 files changed, 41 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index e958eb3..5416a83 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] julia = "1.6" diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index 120bb1e..f226909 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,8 +1,11 @@ module WeightInitializers using Random +using Statistics + include("inits.jl") export zeros32, ones32, rand32, randn32 export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform + end diff --git a/src/inits.jl b/src/inits.jl index 6965186..f0671a4 100644 --- a/src/inits.jl +++ b/src/inits.jl @@ -1,8 +1,8 @@ - @inline _nfan() = 1, 1 # fan_in, fan_out @inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix @inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices @inline _nfan(dims::Tuple) = _nfan(dims...) +@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels function _default_rng() @static if VERSION >= v"1.7" @@ -19,7 +19,7 @@ Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) """ zeros32(rng::AbstractRNG, dims...) = zeros(rng, Float32, dims...) zeros32(dims...) = zeros32(_default_rng(), dims...) -Base.zeros(rng::AbstractRNG, args...) = zeros(args...) +Base.zeros(rng::AbstractRNG, dims...) = zeros(dims...) """ ones32(rng::AbstractRNG, size...) = ones(Float32, size...) @@ -37,6 +37,7 @@ given `size`. """ randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) randn32(dims...) = randn32(_default_rng(), dims...) +randn32(rng::AbstractRNG=_default_rng()) = (dims...,) -> randn32(rng, dims...) """ rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) @@ -46,6 +47,7 @@ Return an `Array{Float32}` of random numbers from a uniform distribution of the """ rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) rand32(dims...) = rand32(_default_rng(), dims...) +rand32(rng::AbstractRNG=_default_rng()) = (dims...,) -> rand32(rng, dims...) """ glorot_uniform(rng::AbstractRNG, size...; gain = 1) @@ -65,7 +67,11 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) scale = Float32(gain) * sqrt(24.0f0 / sum(_nfan(dims...))) return (rand(rng, Float32, dims...) .- 0.5f0) .* scale end -glorot_uniform(dims::Integer...; kw...) = glorot_uniform(_default_rng(), dims...; kwargs...) + +function glorot_uniform(dims::Integer...; kwargs...) + return glorot_uniform(_default_rng(), dims...; kwargs...) +end + function glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) return (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) end @@ -87,9 +93,11 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) std = Float32(gain) * sqrt(2.0f0 / sum(_nfan(dims...))) return randn(rng, Float32, dims...) 
.* std end + function glorot_normal(dims::Integer...; kwargs...) return glorot_normal(_default_rng(), dims...; kwargs...) end + function glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) return (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) end @@ -110,9 +118,11 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0 bound = Float32(√3.0f0 * gain / sqrt(first(_nfan(dims...)))) return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound end + function kaiming_uniform(dims::Integer...; kwargs...) return kaiming_uniform(_default_rng(), dims...; kwargs...) end + function kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) return (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) end @@ -137,6 +147,7 @@ end function kaiming_normal(dims::Integer...; kwargs...) return kaiming_normal(_default_rng(), dims...; kwargs...) end + function kaiming_normal(rng::AbstractRNG; init_kwargs...) return (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) end diff --git a/test/Project.toml b/test/Project.toml index aa8310d..95e58e3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/runtests.jl b/test/runtests.jl index 70bc9a1..0e8d39b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,8 @@ -using WeightInitializers, Test, SafeTestsets, StableRNGs +using WeightInitializers, Test, SafeTestsets, StableRNGs, Statistics const rng = StableRNG(12345) -@testset "inits: $init" for init in [ +@testset "Sizes and Types: $init" for init in [ zeros32, ones32, rand32, @@ -12,18 +12,33 @@ const rng = StableRNG(12345) glorot_uniform, glorot_normal, ] - #sizes + # Sizes @test size(init(3)) == (3,) - @test size(rng, init(3)) == (3,) + @test size(init(rng, 3)) == (3,) @test size(init(3, 4)) == (3, 4) @test size(init(rng, 3, 4)) == (3, 4) @test size(init(3, 4, 5)) == (3, 4, 5) @test size(init(rng, 3, 4, 5)) == (3, 4, 5) - #type + # Type @test eltype(init(rng, 4, 2)) == Float32 @test eltype(init(4, 2)) == Float32 - #closure #TODO @MartinuzzFrancesco +end + +@testset "Closure: $init" for init in [ + rand32, + randn32, + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, +] cl = init(rng) + # Sizes + @test size(cl(3)) == (3,) + @test size(cl(3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(cl(4, 2)) == Float32 end @testset "kaiming" begin @@ -49,7 +64,7 @@ end # variance ≈ 2/(fan_in + fan_out) for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] v = init(dims...) - fan_in, fan_out = nfan(dims...) + fan_in, fan_out = WeightInitializers._nfan(dims...) 
        σ2 = 2 / (fan_in + fan_out)
        @test 0.9σ2 < var(v) < 1.1σ2
    end
    @test eltype(init(3, 4; gain=1.5)) == Float32
end

From 26bddd1407e53b326618255ac118c28d58e8a312 Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Sat, 10 Jun 2023 16:12:46 +0200
Subject: [PATCH 05/11] small fixes, readme

---
 README.md         | 67 ++++++++++++++++++++++++++++++++++++++++++++++-
 docs/src/index.md | 57 +++++++++++++++++++++++++++++++++++++---
 src/inits.jl      | 58 +++++++++++++++++++++++++++++++++-------
 test/runtests.jl  |  4 +++
 4 files changed, 172 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 3e3e641..9f7762c 100644
--- a/README.md
+++ b/README.md
@@ -11,4 +11,69 @@
 [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
 [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
 
-`WeightInitializers.jl` provides common weight initialization schemes for deep learning models.
+This package is a light dependency providing common weight initialization schemes for deep learning models.
+
+## Example
+These code snippets are provided to give a high-level overview
+of the package's functionality.
+Please refer to the [stable documentation](https://luxdl.github.io/WeightInitializers.jl/stable) for more information
+about the package. The
+[under development documentation](https://luxdl.github.io/WeightInitializers.jl/dev)
+provides information on features not yet released.
+
+```julia
+using WeightInitializers, Random
+
+# Fixing rng
+rng = Random.MersenneTwister(42)
+
+# Explicit rng call
+weights = kaiming_normal(rng, 2, 5)
+#2×5 Matrix{Float32}:
+# -0.351662 0.0171745 1.12442 -0.296372 -1.67094
+# -0.281053 -0.18941 -0.724099 0.0987538 0.634549
+
+# Default rng call
+weights = kaiming_normal(2, 5)
+#2×5 Matrix{Float32}:
+# -0.227513 -0.265372 0.265788 1.29955 -0.192836
+# 0.687611 0.454679 -0.433656 0.20548 0.292002
+
+# Passing kwargs (if needed) with explicit rng call
+weights_cl = kaiming_normal(rng; gain=1.0)
+weights = weights_cl(rng, 2, 5)
+#2×5 Matrix{Float32}:
+# 0.484056 0.231723 0.164379 0.306147 0.18365
+# 0.0836414 0.666965 -0.396323 -0.711329 -0.382971
+
+# Passing kwargs (if needed) with default rng call
+weights_cl = kaiming_normal(; gain=1.0)
+weights = weights_cl(2, 5)
+#2×5 Matrix{Float32}:
+# -0.160876 -0.187646 0.18794 0.918918 -0.136356
+# 0.486214 0.321506 -0.306641 0.145296 0.206476
+```
+
+## API
+
+The package is meant to be working with deep learning
+libraries such as F/Lux. All the methods take as input the chosen `rng` type and the dimension for the array.
+```julia
+weights = init(rng, dims...)
+```
+
+The `rng` is optional, if not specified a default one will be used.
+```julia
+weights = init(dims...)
+```
+
+If there is the need to use keyword arguments the methods can be called with just the `rng` (optionally)
+and the keywords to get in return a function behaving like the
+two examples above.
+```julia
+weights_init = init(rng; kwargs...)
+weights = weights_init(rng, dims...)
+# or
+weights_init = init(; kwargs...)
+weights = weights_init(dims...)
+```

diff --git a/docs/src/index.md b/docs/src/index.md
index dc2fbb3..345f450 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -17,10 +17,59 @@
 CurrentModule = WeightInitializers
 ```
 
-## API Reference
+```julia
+using WeightInitializers, Random
 
-### Index
+# Fixing rng
+rng = Random.MersenneTwister(42)
 
-```@index
-Pages = ["index.md"]
+# Explicit rng call
+weights = kaiming_normal(rng, 2, 5)
+#2×5 Matrix{Float32}:
+# -0.351662 0.0171745 1.12442 -0.296372 -1.67094
+# -0.281053 -0.18941 -0.724099 0.0987538 0.634549
+
+# Default rng call
+weights = kaiming_normal(2, 5)
+#2×5 Matrix{Float32}:
+# -0.227513 -0.265372 0.265788 1.29955 -0.192836
+# 0.687611 0.454679 -0.433656 0.20548 0.292002
+
+# Passing kwargs (if needed) with explicit rng call
+weights_cl = kaiming_normal(rng; gain=1.0)
+weights = weights_cl(rng, 2, 5)
+#2×5 Matrix{Float32}:
+# 0.484056 0.231723 0.164379 0.306147 0.18365
+# 0.0836414 0.666965 -0.396323 -0.711329 -0.382971
+
+# Passing kwargs (if needed) with default rng call
+weights_cl = kaiming_normal(; gain=1.0)
+weights = weights_cl(2, 5)
+#2×5 Matrix{Float32}:
+# -0.160876 -0.187646 0.18794 0.918918 -0.136356
+# 0.486214 0.321506 -0.306641 0.145296 0.206476
 ```
+
+## Quick examples
+
+The package is meant to work with deep learning
+libraries such as F/Lux. All the methods take as input the chosen `rng` type and the dimensions of the array.
+```julia
+weights = init(rng, dims...)
+```
+
+The `rng` is optional; if not specified, a default one will be used.
+```julia
+weights = init(dims...)
+```
+
+If keyword arguments are needed, the methods can be called with just the `rng` (optional)
+and the keywords, returning a function that behaves like the
+two examples above.
+```julia
+weights_init = init(rng; kwargs...)
+weights = weights_init(rng, dims...)
+# or
+weights_init = init(; kwargs...)
+weights = weights_init(dims...)
+```
\ No newline at end of file
diff --git a/src/inits.jl b/src/inits.jl
index f0671a4..15d490b 100644
--- a/src/inits.jl
+++ b/src/inits.jl
@@ -37,7 +37,8 @@ given `size`.
 """
 randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...)
 randn32(dims...) = randn32(_default_rng(), dims...)
-randn32(rng::AbstractRNG=_default_rng()) = (dims...,) -> randn32(rng, dims...)
+randn32(rng::AbstractRNG) = (rng, dims...) -> randn32(rng, dims...)
+randn32() = (dims...,) -> randn32(_default_rng(), dims...)
 
 """
     rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...)
@@ -47,7 +48,8 @@ Return an `Array{Float32}` of random numbers from a uniform distribution of the
 """
 rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...)
 rand32(dims...) = rand32(_default_rng(), dims...)
-rand32(rng::AbstractRNG=_default_rng()) = (dims...,) -> rand32(rng, dims...)
+rand32(rng::AbstractRNG) = (rng, dims...) -> rand32(rng, dims...)
+rand32() = (dims...,) -> rand32(_default_rng(), dims...)
 
 """
     glorot_uniform(rng::AbstractRNG, size...; gain = 1)
@@ -72,8 +74,18 @@ function glorot_uniform(dims::Integer...; kwargs...)
     return glorot_uniform(_default_rng(), dims...; kwargs...)
 end
 
-function glorot_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...)
-    return (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)
+function glorot_uniform(rng::AbstractRNG; init_kwargs...)
+    return (rng, dims...; kwargs...) -> glorot_uniform(rng,
+        dims...;
+        init_kwargs...,
+        kwargs...)
+end
+
+function glorot_uniform(; init_kwargs...)
+    return (dims...; kwargs...)
-> glorot_uniform(_default_rng(), + dims...; + init_kwargs..., + kwargs...) end """ @@ -98,10 +110,19 @@ function glorot_normal(dims::Integer...; kwargs...) return glorot_normal(_default_rng(), dims...; kwargs...) end -function glorot_normal(rng::AbstractRNG=_default_rng(); init_kwargs...) - return (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) +function glorot_normal(rng::AbstractRNG; init_kwargs...) + return (rng, dims...; kwargs...) -> glorot_normal(rng, + dims...; + init_kwargs..., + kwargs...) end +function glorot_normal(; init_kwargs...) + return (dims...; kwargs...) -> glorot_normal(_default_rng(), + dims...; + init_kwargs..., + kwargs...) +end """ kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) @@ -123,10 +144,19 @@ function kaiming_uniform(dims::Integer...; kwargs...) return kaiming_uniform(_default_rng(), dims...; kwargs...) end -function kaiming_uniform(rng::AbstractRNG=_default_rng(); init_kwargs...) - return (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) +function kaiming_uniform(rng::AbstractRNG; init_kwargs...) + return (rng, dims...; kwargs...) -> kaiming_uniform(rng, + dims...; + init_kwargs..., + kwargs...) end +function kaiming_uniform(; init_kwargs...) + return (dims...; kwargs...) -> kaiming_uniform(_default_rng(), + dims...; + init_kwargs..., + kwargs...) +end """ kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) @@ -149,5 +179,15 @@ function kaiming_normal(dims::Integer...; kwargs...) end function kaiming_normal(rng::AbstractRNG; init_kwargs...) - return (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) + return (rng, dims...; kwargs...) -> kaiming_normal(rng, + dims...; + init_kwargs..., + kwargs...) +end + +function kaiming_normal(; init_kwargs...) + return (dims...; kwargs...) -> kaiming_normal(_default_rng(), + dims...; + init_kwargs..., + kwargs...) 
end diff --git a/test/runtests.jl b/test/runtests.jl index 0e8d39b..4ee5462 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,10 +35,14 @@ end cl = init(rng) # Sizes @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) # Type @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 end @testset "kaiming" begin From 1f7e28cbfeab317023560501d1200e507e8d9f8a Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 10 Jun 2023 16:19:04 +0200 Subject: [PATCH 06/11] api docs --- docs/src/api.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 docs/src/api.md diff --git a/docs/src/api.md b/docs/src/api.md new file mode 100644 index 0000000..83a0a5b --- /dev/null +++ b/docs/src/api.md @@ -0,0 +1,12 @@ +# Weight Initializers + +```@docs +zeros32 +ones32 +rand32 +randn32 +glorot_normal +glorot_uniform +kaiming_normal +kaiming_uniform +``` From 949b70118a61d562dc09509846bbdefd5573cdb1 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 10 Jun 2023 16:22:51 +0200 Subject: [PATCH 07/11] version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5416a83..429dd19 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" From bd44653cea71106bf9505ec2ed13ab838b6c8b81 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 12 Jun 2023 23:00:28 +0200 Subject: [PATCH 08/11] added truncated_normal --- Project.toml | 1 + docs/src/api.md | 1 + src/WeightInitializers.jl | 2 ++ src/inits.jl | 38 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 5 files changed, 44 insertions(+) diff --git a/Project.toml b/Project.toml index 429dd19..cd6a7e8 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.1" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/docs/src/api.md b/docs/src/api.md index 83a0a5b..4016aa4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -9,4 +9,5 @@ glorot_normal glorot_uniform kaiming_normal kaiming_uniform +truncated_normal ``` diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index f226909..89bdb1c 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,5 +1,6 @@ module WeightInitializers using Random +using SpecialFunctions using Statistics include("inits.jl") @@ -7,5 +8,6 @@ include("inits.jl") export zeros32, ones32, rand32, randn32 export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform +export truncated_normal end diff --git a/src/inits.jl b/src/inits.jl index 15d490b..e703184 100644 --- a/src/inits.jl +++ b/src/inits.jl @@ -191,3 +191,41 @@ function kaiming_normal(; init_kwargs...) init_kwargs..., kwargs...) end + +""" + truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) + +Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution. +The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`. 
+""" +function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo=-2, hi=2) + norm_cdf(x) = 0.5 * (1 + erf(x / √2)) + if (mean < lo - 2 * std) || (mean > hi + 2 * std) + @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 + end + l = norm_cdf((lo - mean) / std) + u = norm_cdf((hi - mean) / std) + xs = rand(rng, Float32, dims...) + broadcast!(xs, xs) do x + x = x * 2(u - l) + (2l - 1) + x = erfinv(x) + return x = clamp(x * std * √2 + mean, lo, hi) + end + return xs +end + +function truncated_normal(dims::Integer...; kwargs...) + return truncated_normal(_default_rng(), dims...; kwargs...) +end +function truncated_normal(rng::AbstractRNG; init_kwargs...) + return (rng, dims...; kwargs...) -> truncated_normal(rng, + dims...; + init_kwargs..., + kwargs...) +end +function truncated_normal(; init_kwargs...) + return (dims...; kwargs...) -> truncated_normal(_default_rng(), + dims...; + init_kwargs..., + kwargs...) +end diff --git a/test/runtests.jl b/test/runtests.jl index 4ee5462..c496840 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ const rng = StableRNG(12345) kaiming_normal, glorot_uniform, glorot_normal, + truncated_normal, ] # Sizes @test size(init(3)) == (3,) @@ -31,6 +32,7 @@ end kaiming_normal, glorot_uniform, glorot_normal, + truncated_normal, ] cl = init(rng) # Sizes From ff85d65ee237ecfec05305c22fa83b892f901883 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 19 Jun 2023 22:24:23 +0200 Subject: [PATCH 09/11] added PartialFunctions, some tests --- Project.toml | 3 +- src/WeightInitializers.jl | 2 ++ src/inits.jl | 67 +++++++-------------------------------- test/runtests.jl | 17 ++++++++-- 4 files changed, 29 insertions(+), 60 deletions(-) diff --git a/Project.toml b/Project.toml index cd6a7e8..6bffc6f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.1" +version = "0.1.0" [deps] +PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index 89bdb1c..fb56218 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,4 +1,6 @@ module WeightInitializers + +using PartialFunctions using Random using SpecialFunctions using Statistics diff --git a/src/inits.jl b/src/inits.jl index e703184..f826fec 100644 --- a/src/inits.jl +++ b/src/inits.jl @@ -3,6 +3,7 @@ @inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices @inline _nfan(dims::Tuple) = _nfan(dims...) @inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels +norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) function _default_rng() @static if VERSION >= v"1.7" @@ -37,8 +38,6 @@ given `size`. """ randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) randn32(dims...) = randn32(_default_rng(), dims...) -randn32(rng::AbstractRNG) = (rng, dims...) -> randn32(rng, dims...) -randn32() = (dims...,) -> randn32(_default_rng(), dims...) """ rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) 
@@ -48,8 +47,6 @@ Return an `Array{Float32}` of random numbers from a uniform distribution of the """ rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) rand32(dims...) = rand32(_default_rng(), dims...) -rand32(rng::AbstractRNG) = (rng, dims...) -> rand32(rng, dims...) -rand32() = (dims...,) -> rand32(_default_rng(), dims...) """ glorot_uniform(rng::AbstractRNG, size...; gain = 1) @@ -74,18 +71,8 @@ function glorot_uniform(dims::Integer...; kwargs...) return glorot_uniform(_default_rng(), dims...; kwargs...) end -function glorot_uniform(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> glorot_uniform(rng, - dims...; - init_kwargs..., - kwargs...) -end - -function glorot_uniform(; init_kwargs...) - return (dims...; kwargs...) -> glorot_uniform(_default_rng(), - dims...; - init_kwargs..., - kwargs...) +function glorot_uniform(; kwargs...) + return glorot_uniform $ (; kwargs...) end """ @@ -110,19 +97,10 @@ function glorot_normal(dims::Integer...; kwargs...) return glorot_normal(_default_rng(), dims...; kwargs...) end -function glorot_normal(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> glorot_normal(rng, - dims...; - init_kwargs..., - kwargs...) +function glorot_normal(rng::AbstractRNG; kwargs...) + return glorot_normal $ (; kwargs...) end -function glorot_normal(; init_kwargs...) - return (dims...; kwargs...) -> glorot_normal(_default_rng(), - dims...; - init_kwargs..., - kwargs...) -end """ kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) @@ -144,19 +122,10 @@ function kaiming_uniform(dims::Integer...; kwargs...) return kaiming_uniform(_default_rng(), dims...; kwargs...) end -function kaiming_uniform(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> kaiming_uniform(rng, - dims...; - init_kwargs..., - kwargs...) +function kaiming_uniform(rng::AbstractRNG; kwargs...) + return kaiming_uniform $ (; kwargs...) end -function kaiming_uniform(; init_kwargs...) - return (dims...; kwargs...) -> kaiming_uniform(_default_rng(), - dims...; - init_kwargs..., - kwargs...) -end """ kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) @@ -178,18 +147,8 @@ function kaiming_normal(dims::Integer...; kwargs...) return kaiming_normal(_default_rng(), dims...; kwargs...) end -function kaiming_normal(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> kaiming_normal(rng, - dims...; - init_kwargs..., - kwargs...) -end - -function kaiming_normal(; init_kwargs...) - return (dims...; kwargs...) -> kaiming_normal(_default_rng(), - dims...; - init_kwargs..., - kwargs...) +function kaiming_normal(rng::AbstractRNG; kwargs...) + return kaiming_normal $ (; kwargs...) end """ @@ -199,7 +158,6 @@ Return an `Array{Float32}` of the given `size` where each element is drawn from The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`. """ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo=-2, hi=2) - norm_cdf(x) = 0.5 * (1 + erf(x / √2)) if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 end @@ -223,9 +181,6 @@ function truncated_normal(rng::AbstractRNG; init_kwargs...) init_kwargs..., kwargs...) end -function truncated_normal(; init_kwargs...) - return (dims...; kwargs...) -> truncated_normal(_default_rng(), - dims...; - init_kwargs..., - kwargs...) +function truncated_normal(; kwargs...) + return truncated_normal $ (; kwargs...) 
end diff --git a/test/runtests.jl b/test/runtests.jl index c496840..4be6ccb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,19 @@ using WeightInitializers, Test, SafeTestsets, StableRNGs, Statistics const rng = StableRNG(12345) +@testset "_nfan" begin + # Fallback + @test WeightInitializers._nfan() == (1, 1) + # Vector + @test WeightInitializers._nfan(4) == (1, 4) + # Matrix + @test WeightInitializers._nfan(4, 5) == (5, 4) + # Tuple + @test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6) + # Convolution + @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) +end + @testset "Sizes and Types: $init" for init in [ zeros32, ones32, @@ -26,15 +39,13 @@ const rng = StableRNG(12345) end @testset "Closure: $init" for init in [ - rand32, - randn32, kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, ] - cl = init(rng) + cl = init(;) # Sizes @test size(cl(3)) == (3,) @test size(cl(rng, 3)) == (3,) From c1bc4813e61325dbc23537da0114131161e24d4d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Jun 2023 10:04:26 -0400 Subject: [PATCH 10/11] Minor restructuring --- src/WeightInitializers.jl | 8 +++----- src/{inits.jl => initializers.jl} | 26 +++++--------------------- src/utils.jl | 14 ++++++++++++++ 3 files changed, 22 insertions(+), 26 deletions(-) rename src/{inits.jl => initializers.jl} (87%) create mode 100644 src/utils.jl diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index fb56218..6d70386 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,11 +1,9 @@ module WeightInitializers -using PartialFunctions -using Random -using SpecialFunctions -using Statistics +using PartialFunctions, Random, SpecialFunctions, Statistics -include("inits.jl") +include("utils.jl") +include("initializers.jl") export zeros32, ones32, rand32, randn32 export glorot_normal, glorot_uniform diff --git a/src/inits.jl b/src/initializers.jl similarity index 87% rename from src/inits.jl rename to src/initializers.jl index f826fec..3f15ce0 100644 --- a/src/inits.jl +++ b/src/initializers.jl @@ -1,34 +1,18 @@ -@inline _nfan() = 1, 1 # fan_in, fan_out -@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix -@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices -@inline _nfan(dims::Tuple) = _nfan(dims...) -@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels -norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) - -function _default_rng() - @static if VERSION >= v"1.7" - return Xoshiro(1234) - else - return MersenneTwister(1234) - end -end - """ zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) """ -zeros32(rng::AbstractRNG, dims...) = zeros(rng, Float32, dims...) +zeros32(::AbstractRNG, dims...) = zeros(Float32, dims...) zeros32(dims...) = zeros32(_default_rng(), dims...) -Base.zeros(rng::AbstractRNG, dims...) = zeros(dims...) + """ ones32(rng::AbstractRNG, size...) = ones(Float32, size...) Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored) """ -ones32(rng::AbstractRNG, dims...) = ones(rng, Float32, dims...) +ones32(::AbstractRNG, dims...) = ones(Float32, dims...) ones32(dims...) = ones32(_default_rng(), dims...) -Base.ones(rng::AbstractRNG, dims...) = ones(dims...) """ randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...) 
@@ -161,8 +145,8 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo= if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 end - l = norm_cdf((lo - mean) / std) - u = norm_cdf((hi - mean) / std) + l = _norm_cdf((lo - mean) / std) + u = _norm_cdf((hi - mean) / std) xs = rand(rng, Float32, dims...) broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - 1) diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..325dcac --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,14 @@ +@inline _nfan() = 1, 1 # fan_in, fan_out +@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix +@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices +@inline _nfan(dims::Tuple) = _nfan(dims...) +@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels +_norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) + +function _default_rng() + @static if VERSION >= v"1.7" + return Xoshiro(1234) + else + return MersenneTwister(1234) + end +end From a22bdf3ce653b2a8faa68f2f2208bfb509f8ed83 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Jun 2023 10:22:03 -0400 Subject: [PATCH 11/11] Cleanup the codebase using MetaProgramming --- README.md | 6 +- docs/mkdocs.yml | 1 + src/initializers.jl | 68 ++++++------------- src/utils.jl | 3 + test/runtests.jl | 156 ++++++++++++++++++++++---------------------- 5 files changed, 106 insertions(+), 128 deletions(-) diff --git a/README.md b/README.md index 9f7762c..56db605 100644 --- a/README.md +++ b/README.md @@ -58,18 +58,20 @@ weights = weights_cl(2, 5) The package is meant to be working with deep learning libraries such as F/Lux. All the methods take as input the chosen `rng` type and the dimension for the array. + ```julia weights = init(rng, dims...) ``` The `rng` is optional, if not specified a default one will be used. + ```julia weights = init(dims...) ``` If there is the need to use keyword arguments the methods can be called with just the `rng` (optionally) -and the keywords to get in return a function behaving like the -two examples above. +and the keywords to get in return a function behaving like the two examples above. + ```julia weights_init = init(rng; kwargs...) weights = weights_init(rng, dims...) diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 2ad45a6..77b6ad3 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -87,3 +87,4 @@ plugins: nav: - "WeightInitializers.jl": "index.md" + - "API Reference": "api.md" diff --git a/src/initializers.jl b/src/initializers.jl index 3f15ce0..b05c38c 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -1,18 +1,16 @@ """ - zeros32(rng::AbstractRNG, size...) = zeros(Float32, size...) + zeros32(::AbstractRNG, size...) = zeros(Float32, size...) Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored) """ zeros32(::AbstractRNG, dims...) = zeros(Float32, dims...) -zeros32(dims...) = zeros32(_default_rng(), dims...) """ - ones32(rng::AbstractRNG, size...) = ones(Float32, size...) + ones32(::AbstractRNG, size...) = ones(Float32, size...) Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored) """ ones32(::AbstractRNG, dims...) = ones(Float32, dims...) -ones32(dims...) = ones32(_default_rng(), dims...) """ randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...) 
@@ -21,7 +19,6 @@ Return an `Array{Float32}` of random numbers from a standard normal distribution given `size`. """ randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) -randn32(dims...) = randn32(_default_rng(), dims...) """ rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...) @@ -30,7 +27,6 @@ Return an `Array{Float32}` of random numbers from a uniform distribution of the `size`. """ rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...) -rand32(dims...) = rand32(_default_rng(), dims...) """ glorot_uniform(rng::AbstractRNG, size...; gain = 1) @@ -51,14 +47,6 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) return (rand(rng, Float32, dims...) .- 0.5f0) .* scale end -function glorot_uniform(dims::Integer...; kwargs...) - return glorot_uniform(_default_rng(), dims...; kwargs...) -end - -function glorot_uniform(; kwargs...) - return glorot_uniform $ (; kwargs...) -end - """ glorot_normal(rng::AbstractRNG, size...; gain = 1) @@ -77,14 +65,6 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) return randn(rng, Float32, dims...) .* std end -function glorot_normal(dims::Integer...; kwargs...) - return glorot_normal(_default_rng(), dims...; kwargs...) -end - -function glorot_normal(rng::AbstractRNG; kwargs...) - return glorot_normal $ (; kwargs...) -end - """ kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0) @@ -102,14 +82,6 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0 return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound end -function kaiming_uniform(dims::Integer...; kwargs...) - return kaiming_uniform(_default_rng(), dims...; kwargs...) -end - -function kaiming_uniform(rng::AbstractRNG; kwargs...) - return kaiming_uniform $ (; kwargs...) -end - """ kaiming_normal(rng::AbstractRNG, size...; gain = √2f0) @@ -127,14 +99,6 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0) return randn(rng, Float32, dims...) .* std end -function kaiming_normal(dims::Integer...; kwargs...) - return kaiming_normal(_default_rng(), dims...; kwargs...) -end - -function kaiming_normal(rng::AbstractRNG; kwargs...) - return kaiming_normal $ (; kwargs...) -end - """ truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) @@ -156,15 +120,21 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo= return xs end -function truncated_normal(dims::Integer...; kwargs...) - return truncated_normal(_default_rng(), dims...; kwargs...) -end -function truncated_normal(rng::AbstractRNG; init_kwargs...) - return (rng, dims...; kwargs...) -> truncated_normal(rng, - dims...; - init_kwargs..., - kwargs...) -end -function truncated_normal(; kwargs...) - return truncated_normal $ (; kwargs...) +# Default Fallbacks for all functions +for initializer in (:zeros32, + :ones32, + :randn32, + :rand32, + :glorot_uniform, + :glorot_normal, + :kaiming_uniform, + :kaiming_normal, + :truncated_normal) + @eval function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), dims...; kwargs...) + end + @eval function ($initializer)(rng::AbstractRNG; kwargs...) + return _partial_apply($initializer, (rng, (; kwargs...))) + end + @eval ($initializer)(; kwargs...) 
= _partial_apply($initializer, (; kwargs...)) end diff --git a/src/utils.jl b/src/utils.jl index 325dcac..b26253e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,3 +12,6 @@ function _default_rng() return MersenneTwister(1234) end end + +# This is needed if using `PartialFunctions.$` inside @eval block +_partial_apply(fn, inp) = fn$inp diff --git a/test/runtests.jl b/test/runtests.jl index 4be6ccb..7120d1e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,88 +2,90 @@ using WeightInitializers, Test, SafeTestsets, StableRNGs, Statistics const rng = StableRNG(12345) -@testset "_nfan" begin - # Fallback - @test WeightInitializers._nfan() == (1, 1) - # Vector - @test WeightInitializers._nfan(4) == (1, 4) - # Matrix - @test WeightInitializers._nfan(4, 5) == (5, 4) - # Tuple - @test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6) - # Convolution - @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) -end +@testset "WeightInitializers.jl Tests" begin + @testset "_nfan" begin + # Fallback + @test WeightInitializers._nfan() == (1, 1) + # Vector + @test WeightInitializers._nfan(4) == (1, 4) + # Matrix + @test WeightInitializers._nfan(4, 5) == (5, 4) + # Tuple + @test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6) + # Convolution + @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6) + end -@testset "Sizes and Types: $init" for init in [ - zeros32, - ones32, - rand32, - randn32, - kaiming_uniform, - kaiming_normal, - glorot_uniform, - glorot_normal, - truncated_normal, -] - # Sizes - @test size(init(3)) == (3,) - @test size(init(rng, 3)) == (3,) - @test size(init(3, 4)) == (3, 4) - @test size(init(rng, 3, 4)) == (3, 4) - @test size(init(3, 4, 5)) == (3, 4, 5) - @test size(init(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(init(rng, 4, 2)) == Float32 - @test eltype(init(4, 2)) == Float32 -end + @testset "Sizes and Types: $init" for init in [ + zeros32, + ones32, + rand32, + randn32, + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, + truncated_normal, + ] + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(init(rng, 4, 2)) == Float32 + @test eltype(init(4, 2)) == Float32 + end -@testset "Closure: $init" for init in [ - kaiming_uniform, - kaiming_normal, - glorot_uniform, - glorot_normal, - truncated_normal, -] - cl = init(;) - # Sizes - @test size(cl(3)) == (3,) - @test size(cl(rng, 3)) == (3,) - @test size(cl(3, 4)) == (3, 4) - @test size(cl(rng, 3, 4)) == (3, 4) - @test size(cl(3, 4, 5)) == (3, 4, 5) - @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) - # Type - @test eltype(cl(4, 2)) == Float32 - @test eltype(cl(rng, 4, 2)) == Float32 -end + @testset "Closure: $init" for init in [ + kaiming_uniform, + kaiming_normal, + glorot_uniform, + glorot_normal, + truncated_normal, + ] + cl = init(;) + # Sizes + @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end -@testset "kaiming" begin - # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] - # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) - for (n_in, n_out) in [(100, 
100), (100, 400)] - v = kaiming_uniform(rng, n_in, n_out) - σ2 = sqrt(6 / n_out) - @test -1σ2 < minimum(v) < -0.9σ2 - @test 0.9σ2 < maximum(v) < 1σ2 + @testset "kaiming" begin + # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] + # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = kaiming_uniform(rng, n_in, n_out) + σ2 = sqrt(6 / n_out) + @test -1σ2 < minimum(v) < -0.9σ2 + @test 0.9σ2 < maximum(v) < 1σ2 - v = kaiming_normal(rng, n_in, n_out) - σ2 = sqrt(2 / n_out) - @test 0.9σ2 < std(v) < 1.1σ2 + v = kaiming_normal(rng, n_in, n_out) + σ2 = sqrt(2 / n_out) + @test 0.9σ2 < std(v) < 1.1σ2 + end + # Type + @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5)) == Float32 + @test eltype(kaiming_normal(rng, 3, 4; gain=1.5)) == Float32 end - # - @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5)) == Float32 - @test eltype(kaiming_normal(rng, 3, 4; gain=1.5)) == Float32 -end -@testset "glorot: $init" for init in [glorot_uniform, glorot_normal] - # glorot_uniform and glorot_normal should both yield a kernel with - # variance ≈ 2/(fan_in + fan_out) - for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] - v = init(dims...) - fan_in, fan_out = WeightInitializers._nfan(dims...) - σ2 = 2 / (fan_in + fan_out) - @test 0.9σ2 < var(v) < 1.1σ2 + @testset "glorot: $init" for init in [glorot_uniform, glorot_normal] + # glorot_uniform and glorot_normal should both yield a kernel with + # variance ≈ 2/(fan_in + fan_out) + for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + v = init(dims...) + fan_in, fan_out = WeightInitializers._nfan(dims...) + σ2 = 2 / (fan_in + fan_out) + @test 0.9σ2 < var(v) < 1.1σ2 + end + @test eltype(init(3, 4; gain=1.5)) == Float32 end - @test eltype(init(3, 4; gain=1.5)) == Float32 end
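
To close out the series, the final API can be exercised end-to-end (a minimal sketch reusing the test suite's `StableRNG` seed; the dimensions and tolerance below are arbitrary):

```julia
using WeightInitializers, StableRNGs, Statistics

rng = StableRNG(12345)

# Keyword-only partial application, one of the method stubs generated by the
# @eval loop in PATCH 11 via `_partial_apply` / PartialFunctions' `$`:
init = glorot_uniform(; gain=1.0f0)

W = init(rng, 100, 400)  # rng supplied at call time
V = init(100, 400)       # falls back to _default_rng()
@assert eltype(W) == eltype(V) == Float32

# glorot_uniform targets Var ≈ 2 / (fan_in + fan_out); here
# _nfan(100, 400) == (400, 100), so the target is 2 / 500.
fan_in, fan_out = WeightInitializers._nfan(100, 400)
@assert isapprox(var(W), 2 / (fan_in + fan_out); rtol=0.1)

# truncated_normal clamps its inverse-CDF samples, so all values lie in [lo, hi].
T = truncated_normal(rng, 10, 10; mean=0, std=1, lo=-2, hi=2)
@assert all(-2 .<= T .<= 2)
```

The same partial-application pattern works for every exported initializer, mirroring the "Closure" testset above.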