This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Merge pull request #1 from LuxDL/fm/start
[WIP] Adding initializers
avik-pal authored Jun 21, 2023
2 parents 0ae62b7 + f92fba3 commit f85f994
Showing 10 changed files with 402 additions and 6 deletions.
6 changes: 6 additions & 0 deletions Project.toml
@@ -3,5 +3,11 @@ uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.0"

[deps]
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
julia = "1.6"
69 changes: 68 additions & 1 deletion README.md
@@ -11,4 +11,71 @@
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

`WeightInitializers.jl` provides common weight initialization schemes for deep learning models.
This package is a light dependency providing common weight initialization schemes for deep learning models.

## Example
These code snippets give a high-level overview of the package's functionality.
Please refer to the [stable documentation](https://luxdl.github.io/WeightInitializers.jl/stable) for more information
about the package. The
[under-development documentation](https://luxdl.github.io/WeightInitializers.jl/dev)
covers features that have not yet been released.

```julia
using WeightInitializers, Random

# Fixing rng
rng = Random.MersenneTwister(42)

# Explicit rng call
weights = kaiming_normal(rng, 2, 5)
#2×5 Matrix{Float32}:
# -0.351662 0.0171745 1.12442 -0.296372 -1.67094
# -0.281053 -0.18941 -0.724099 0.0987538 0.634549

# Default rng call
weights = kaiming_normal(2, 5)
#2×5 Matrix{Float32}:
# -0.227513 -0.265372 0.265788 1.29955 -0.192836
# 0.687611 0.454679 -0.433656 0.20548 0.292002

# Passing kwargs (if needed) with explicit rng call
weights_cl = kaiming_normal(rng; gain=1.0)
weights = weights_cl(rng, 2, 5)
#2×5 Matrix{Float32}:
# 0.484056 0.231723 0.164379 0.306147 0.18365
# 0.0836414 0.666965 -0.396323 -0.711329 -0.382971

# Passing kwargs (if needed) with default rng call
weights_cl = kaiming_normal(; gain=1.0)
weights = weights_cl(2, 5)
#2×5 Matrix{Float32}:
# -0.160876 -0.187646 0.18794 0.918918 -0.136356
# 0.486214 0.321506 -0.306641 0.145296 0.206476
```

## API

The package is designed to work with deep learning
libraries such as Flux and Lux. All methods take the chosen `rng` and the dimensions of the array as input.

```julia
weights = init(rng, dims...)
```

The `rng` is optional; if not specified, a default one will be used.

```julia
weights = init(dims...)
```

If keyword arguments are needed, the methods can be called with just the `rng` (optionally)
and the keyword arguments, returning a function that behaves like the two examples above.

```julia
weights_init = init(rng; kwargs...)
weights = weights_init(rng, dims...)
# or
weights_init = init(; kwargs...)
weights = weights_init(dims...)
```
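
As a further, purely illustrative sketch (not part of the package documentation), the same closure pattern can be reused with the other exported initializers. The keyword names below (`gain` for `glorot_uniform`; `mean` and `std` for `truncated_normal`) follow the respective docstrings, and the shapes are arbitrary:

```julia
using WeightInitializers

# Build partially applied initializers once and reuse them for several shapes
glorot_init = glorot_uniform(; gain=1.0)
trunc_init = truncated_normal(; mean=0.0, std=0.02)

W1 = glorot_init(3, 4)    # 3×4 Matrix{Float32}
W2 = trunc_init(16, 8)    # 16×8 Matrix{Float32}
b = zeros32(4)            # length-4 Vector{Float32} of zeros
```

Such closures follow the `weights_init = init(; kwargs...)` convention shown above, so they can be passed anywhere an initializer function of that form is expected.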
1 change: 1 addition & 0 deletions docs/mkdocs.yml
@@ -87,3 +87,4 @@ plugins:

nav:
- "WeightInitializers.jl": "index.md"
- "API Reference": "api.md"
13 changes: 13 additions & 0 deletions docs/src/api.md
@@ -0,0 +1,13 @@
# Weight Initializers

```@docs
zeros32
ones32
rand32
randn32
glorot_normal
glorot_uniform
kaiming_normal
kaiming_uniform
truncated_normal
```
57 changes: 53 additions & 4 deletions docs/src/index.md
@@ -17,10 +17,59 @@
CurrentModule = WeightInitializers
```

## API Reference
```julia
using WeightInitializers, Random

### Index
# Fixing rng
rng = Random.MersenneTwister(42)

```@index
Pages = ["index.md"]
# Explicit rng call
weights = kaiming_normal(rng, 2, 5)
#2×5 Matrix{Float32}:
# -0.351662 0.0171745 1.12442 -0.296372 -1.67094
# -0.281053 -0.18941 -0.724099 0.0987538 0.634549

# Default rng call
weights = kaiming_normal(2, 5)
#2×5 Matrix{Float32}:
# -0.227513 -0.265372 0.265788 1.29955 -0.192836
# 0.687611 0.454679 -0.433656 0.20548 0.292002

# Passing kwargs (if needed) with explicit rng call
weights_cl = kaiming_normal(rng; gain=1.0)
weights = weights_cl(rng, 2, 5)
#2×5 Matrix{Float32}:
# 0.484056 0.231723 0.164379 0.306147 0.18365
# 0.0836414 0.666965 -0.396323 -0.711329 -0.382971

# Passing kwargs (if needed) with default rng call
weights_cl = kaiming_normal(; gain=1.0)
weights = weights_cl(2, 5)
#2×5 Matrix{Float32}:
# -0.160876 -0.187646 0.18794 0.918918 -0.136356
# 0.486214 0.321506 -0.306641 0.145296 0.206476
```

## Quick examples

The package is designed to work with deep learning
libraries such as Flux and Lux. All methods take the chosen `rng` and the dimensions of the array as input.
```julia
weights = init(rng, dims...)
```

The `rng` is optional; if not specified, a default one will be used.
```julia
weights = init(dims...)
```

If keyword arguments are needed, the methods can be called with just the `rng` (optionally)
and the keyword arguments, returning a function that behaves like the
two examples above.
```julia
weights_init = init(rng; kwargs...)
weights = weights_init(rng, dims...)
# or
weights_init = init(; kwargs...)
weights = weights_init(dims...)
```
10 changes: 10 additions & 0 deletions src/WeightInitializers.jl
@@ -1,3 +1,13 @@
module WeightInitializers

using PartialFunctions, Random, SpecialFunctions, Statistics

include("utils.jl")
include("initializers.jl")

export zeros32, ones32, rand32, randn32
export glorot_normal, glorot_uniform
export kaiming_normal, kaiming_uniform
export truncated_normal

end
140 changes: 140 additions & 0 deletions src/initializers.jl
@@ -0,0 +1,140 @@
"""
zeros32(::AbstractRNG, size...) = zeros(Float32, size...)
Return an `Array{Float32}` of zeros of the given `size`. (`rng` is ignored)
"""
zeros32(::AbstractRNG, dims...) = zeros(Float32, dims...)

"""
ones32(::AbstractRNG, size...) = ones(Float32, size...)
Return an `Array{Float32}` of ones of the given `size`. (`rng` is ignored)
"""
ones32(::AbstractRNG, dims...) = ones(Float32, dims...)

"""
randn32(rng::AbstractRNG, size...) = randn(rng, Float32, size...)
Return an `Array{Float32}` of random numbers from a standard normal distribution of the
given `size`.
"""
randn32(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...)

"""
rand32(rng::AbstractRNG, size...) = rand(rng, Float32, size...)
Return an `Array{Float32}` of random numbers from a uniform distribution of the given
`size`.
"""
rand32(rng::AbstractRNG, dims...) = rand(rng, Float32, dims...)

"""
glorot_uniform(rng::AbstractRNG, size...; gain = 1)
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a
uniform distribution on the interval ``[-x, x]``, where
`x = gain * sqrt(6 / (fan_in + fan_out))`. This method is described in [1] and also known as
Xavier initialization.
# References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep
feedforward neural networks." _Proceedings of the thirteenth international conference on
artificial intelligence and statistics_. 2010.
"""
function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
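# Note: 24 = 4 * 6. Since `rand(...) .- 0.5` spans [-0.5, 0.5], multiplying by
# gain * sqrt(24 / (fan_in + fan_out)) gives the half-width gain * sqrt(6 / (fan_in + fan_out))
# stated in the docstring.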
scale = Float32(gain) * sqrt(24.0f0 / sum(_nfan(dims...)))
return (rand(rng, Float32, dims...) .- 0.5f0) .* scale
end

"""
glorot_normal(rng::AbstractRNG, size...; gain = 1)
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This method is
described in [1] and also known as Xavier initialization.
# References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep
feedforward neural networks." _Proceedings of the thirteenth international conference on
artificial intelligence and statistics_. 2010.
"""
function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
std = Float32(gain) * sqrt(2.0f0 / sum(_nfan(dims...)))
return randn(rng, Float32, dims...) .* std
end

"""
kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0)
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a
uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`.
# References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on
imagenet classification." _Proceedings of the IEEE international conference on computer
vision_. 2015.
"""
function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2f0)
bound = Float32(√3f0 * gain / sqrt(first(_nfan(dims...))))
return (rand(rng, Float32, dims...) .- 0.5f0) .* 2 * bound
end

"""
kaiming_normal(rng::AbstractRNG, size...; gain = √2f0)
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
distribution with standard deviation `gain / sqrt(fan_in)`.
# References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on
imagenet classification." _Proceedings of the IEEE international conference on computer
vision_. 2015.
"""
function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2f0)
std = Float32(gain / sqrt(first(_nfan(dims...))))
return randn(rng, Float32, dims...) .* std
end

"""
truncated_normal([rng = _default_rng()], size...; mean = 0, std = 1, lo = -2, hi = 2)
Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`.
"""
function truncated_normal(rng::AbstractRNG, dims::Integer...; mean=0, std=1, lo=-2, hi=2)
if (mean < lo - 2 * std) || (mean > hi + 2 * std)
@warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
end
l = _norm_cdf((lo - mean) / std)
u = _norm_cdf((hi - mean) / std)
xs = rand(rng, Float32, dims...)
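# Inverse-CDF sampling: map the uniform draws onto [2l - 1, 2u - 1], invert the error
# function, then rescale to the requested mean/std and clamp to [lo, hi].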
broadcast!(xs, xs) do x
x = x * 2(u - l) + (2l - 1)
x = erfinv(x)
return clamp(x * std * √2 + mean, lo, hi)
end
return xs
end

# Default Fallbacks for all functions
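# For each initializer below, three extra methods are generated, e.g. for kaiming_normal:
#   kaiming_normal(dims...; kwargs...)  falls back to _default_rng()
#   kaiming_normal(rng; kwargs...)      returns a partially applied initializer (see README usage)
#   kaiming_normal(; kwargs...)         returns a partially applied initializer (see README usage)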
for initializer in (:zeros32,
:ones32,
:randn32,
:rand32,
:glorot_uniform,
:glorot_normal,
:kaiming_uniform,
:kaiming_normal,
:truncated_normal)
@eval function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
return _partial_apply($initializer, (rng, (; kwargs...)))
end
@eval ($initializer)(; kwargs...) = _partial_apply($initializer, (; kwargs...))
end
17 changes: 17 additions & 0 deletions src/utils.jl
@@ -0,0 +1,17 @@
@inline _nfan() = 1, 1 # fan_in, fan_out
@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix
@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
@inline _nfan(dims::Tuple) = _nfan(dims...)
@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels
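# e.g. a 3×3 conv kernel with 16 input and 32 output channels: _nfan(3, 3, 16, 32) == (144, 288)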
_norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2))

function _default_rng()
@static if VERSION >= v"1.7"
return Xoshiro(1234)
else
return MersenneTwister(1234)
end
end

# This is needed if using `PartialFunctions.$` inside @eval block
_partial_apply(fn, inp) = fn$inp
3 changes: 3 additions & 0 deletions test/Project.toml
@@ -1,4 +1,7 @@
[deps]
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]

2 comments on commit f85f994

@avik-pal
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/86004

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.1.0 -m "<description of version>" f85f994e1a9c2a6aeaa127285c7338e51d228835
git push origin v0.1.0
